In [9]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch.nn.functional as F
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torch.optim import Adam
from torch import nn
import torch


def train(model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over every batch of ``train_loader``.

    Args:
        model: Network to optimise; it is switched to training mode first.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimiser whose gradients are cleared and stepped per batch.
        epoch: Accepted for the caller's convenience; not read by this function.

    Returns:
        None. ``model`` and ``optimizer`` are updated in place.
    """
    model.train()
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard step: clear stale gradients, compute the cross-entropy
        # loss on the model outputs, back-propagate, update the weights.
        optimizer.zero_grad()
        batch_loss = F.cross_entropy(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

def test(model, device, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    return correct / len(test_loader.dataset)

@torch.inference_mode()
def predict(model: nn.Module, loader: DataLoader, device: torch.device):
    """Return the predicted class index for every sample produced by ``loader``.

    Args:
        model: Network to run; it is switched to eval mode.
        loader: Iterable yielding ``(inputs, labels)`` batches. The labels are
            ignored.
        device: Device each input batch is moved to before the forward pass.

    Returns:
        A 1-D ``torch.long`` tensor of argmax indices (taken along dim 1 of the
        model output), concatenated in loader order. An empty loader yields an
        empty ``torch.long`` tensor on ``device``.
    """
    model.eval()
    # Collect per-batch results and concatenate once: this keeps argmax's
    # integer dtype (a float accumulator would promote the class indices to
    # float) and avoids re-copying the growing result on every batch.
    batch_preds = []
    for x, _ in loader:
        output = model(x.to(device))
        batch_preds.append(torch.argmax(output, dim=1))
    if not batch_preds:
        return torch.empty(0, dtype=torch.long, device=device)
    return torch.cat(batch_preds)


@torch.inference_mode()
def predict_tta(model: nn.Module, loader: DataLoader, device: torch.device, iterations: int = 2):
    """Predict classes by averaging model outputs over repeated loader passes.

    The loader is iterated ``iterations`` times. The raw model outputs of each
    pass are concatenated, the passes are averaged element-wise, and the argmax
    along dim 1 of that average is returned.

    Args:
        model: Network to run; it is switched to eval mode.
        loader: Iterable yielding ``(inputs, labels)`` batches. Passes are
            averaged position by position, so this presumes the loader yields
            samples in the same order on every pass.
        device: Device each batch is moved to.
        iterations: Number of full passes over ``loader``.

    Returns:
        1-D tensor of predicted class indices.
    """
    model.eval()

    def _single_pass():
        # One sweep over the loader -> outputs for every sample, in order.
        batch_outputs = []
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)  # moved like the inputs; not used further
            batch_outputs.append(model(inputs))
        return torch.cat(batch_outputs)

    stacked = torch.stack([_single_pass() for _ in range(iterations)])
    averaged = stacked.mean(dim=0)
    return averaged.argmax(dim=1)


In [23]:
def create_simple_conv_cifar() -> nn.Sequential:
    """Build a small two-convolution classifier with ten outputs.

    The first ``Linear`` layer takes ``8 * 8 * 32`` features, which matches
    3-channel 32x32 inputs after the two 2x2 max-poolings.

    Returns:
        A flat ``nn.Sequential``: conv/ReLU/pool twice, then a three-layer MLP.
    """
    feature_layers = [
        nn.Conv2d(3, 16, kernel_size=3, padding=1),   # 16 channels, 32 x 32
        nn.ReLU(),
        nn.MaxPool2d(2),                              # 16 channels, 16 x 16
        nn.Conv2d(16, 32, kernel_size=3, padding=1),  # 32 channels, 16 x 16
        nn.ReLU(),
        nn.MaxPool2d(2),                              # 32 channels, 8 x 8
    ]
    classifier_layers = [
        nn.Flatten(),
        nn.Linear(8 * 8 * 32, 1024),
        nn.ReLU(),
        nn.Linear(1024, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    ]
    return nn.Sequential(*feature_layers, *classifier_layers)

def create_modified_conv_cifar() -> nn.Sequential:
    """Build a three-stage conv net with batch norm and dropout, ten outputs.

    Each stage is two (3x3 conv -> BatchNorm -> ReLU) units followed by a 2x2
    max-pool and ``Dropout2d(0.2)``. Channel widths are 32, 64 and 128. For
    3-channel 32x32 inputs the feature map shrinks 32 -> 16 -> 8 -> 4, which is
    what the ``4 * 4 * 128`` input size of the first ``Linear`` layer requires.

    Returns:
        A flat ``nn.Sequential`` of 30 modules.
    """

    def conv_unit(c_in, c_out):
        # 3x3 convolution that keeps the spatial size, then BN and ReLU.
        return [
            nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
        ]

    def stage(c_in, c_out):
        # Two conv units, then halve the resolution and regularise.
        return [
            *conv_unit(c_in, c_out),
            *conv_unit(c_out, c_out),
            nn.MaxPool2d(2),
            nn.Dropout2d(p=0.2),
        ]

    layers = []
    # (3 -> 32) at 32x32, (32 -> 64) at 16x16, (64 -> 128) at 8x8.
    for c_in, c_out in ((3, 32), (32, 64), (64, 128)):
        layers.extend(stage(c_in, c_out))

    # Classifier head on the flattened 128 x 4 x 4 feature map.
    layers.extend(
        [
            nn.Flatten(),
            nn.Linear(4 * 4 * 128, 512),
            nn.BatchNorm1d(512),
            nn.Dropout(p=0.3),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
    )
    return nn.Sequential(*layers)


def create_modified2_conv_cifar() -> nn.Sequential:
    """Build a VGG-style conv net with batch norm and dropout, ten outputs.

    Three stages with 32, 64 and 128 channels. Each stage holds two 3x3
    convolutions (each followed by BatchNorm and ReLU), a 2x2 max-pool and
    ``Dropout2d(0.2)``. The classifier head expects ``4 * 4 * 128`` flattened
    features, i.e. 3-channel 32x32 inputs.

    Returns:
        A flat ``nn.Sequential`` of 30 modules.
    """
    same_size_conv = dict(kernel_size=3, padding=1)

    # Stage 1: 32 x 32 resolution, pooled to 16 x 16 at the end.
    stage_32 = [
        nn.Conv2d(3, 32, **same_size_conv), nn.BatchNorm2d(32), nn.ReLU(),
        nn.Conv2d(32, 32, **same_size_conv), nn.BatchNorm2d(32), nn.ReLU(),
        nn.MaxPool2d(2), nn.Dropout2d(p=0.2),
    ]
    # Stage 2: 16 x 16 resolution, pooled to 8 x 8 (64 channels).
    stage_64 = [
        nn.Conv2d(32, 64, **same_size_conv), nn.BatchNorm2d(64), nn.ReLU(),
        nn.Conv2d(64, 64, **same_size_conv), nn.BatchNorm2d(64), nn.ReLU(),
        nn.MaxPool2d(2), nn.Dropout2d(p=0.2),
    ]
    # Stage 3: 8 x 8 resolution, pooled to 4 x 4 (128 channels).
    stage_128 = [
        nn.Conv2d(64, 128, **same_size_conv), nn.BatchNorm2d(128), nn.ReLU(),
        nn.Conv2d(128, 128, **same_size_conv), nn.BatchNorm2d(128), nn.ReLU(),
        nn.MaxPool2d(2), nn.Dropout2d(p=0.2),
    ]
    head = [
        nn.Flatten(),
        nn.Linear(4 * 4 * 128, 512),
        nn.BatchNorm1d(512),
        nn.Dropout(p=0.3),
        nn.ReLU(),
        nn.Linear(512, 10),
    ]
    return nn.Sequential(*stage_32, *stage_64, *stage_128, *head)


class IDABlock(nn.Module):
    """Residual block: conv-BN-ReLU-conv-BN, plus a skip connection, then ReLU.

    Both 3x3 convolutions use ``padding=1`` and therefore keep the spatial
    size. The skip path is the identity when ``in_channels == out_channels``.
    Otherwise the input is projected with a 1x1 convolution followed by
    BatchNorm so that it can be added to the main branch.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The main branch outputs `out_channels` channels while the input has
        # `in_channels`; adding them directly only works when the two match.
        # nn.Identity has no parameters, so equal-channel blocks keep the same
        # state_dict keys as a block without a shortcut module.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor with ``out_channels`` channels."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Residual sum followed by the final non-linearity.
        out = self.relu(out + self.shortcut(x))
        return out
    
def create_modified3_conv_cifar() -> nn.Sequential:
    """Build a ten-output classifier from three ``IDABlock`` stages.

    Each stage is an ``IDABlock`` followed by a 2x2 max-pool and
    ``Dropout2d(0.2)``, with channel widths 32, 64 and 128. The head expects
    ``4 * 4 * 128`` flattened features.

    Note: this function name is defined again further down in the file, so the
    definition that was executed last is the one callers get.

    Returns:
        A flat ``nn.Sequential`` of 15 modules.
    """
    layers = []
    for c_in, c_out in ((3, 32), (32, 64), (64, 128)):
        layers += [IDABlock(c_in, c_out), nn.MaxPool2d(2), nn.Dropout2d(p=0.2)]

    layers += [
        nn.Flatten(),
        nn.Linear(4 * 4 * 128, 512),
        nn.BatchNorm1d(512),
        nn.Dropout(p=0.3),
        nn.ReLU(),
        nn.Linear(512, 10),
    ]
    return nn.Sequential(*layers)

In [25]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Four-convolution classifier with batch norm and dropout, ten outputs.

    Layout: two 5x5 convolutions (128 channels), 2x2 max-pool and dropout, two
    3x3 convolutions (256 channels), 2x2 max-pool and dropout, then a
    1024-512-10 fully connected head. ``fc1`` takes ``256 * 8 * 8`` features,
    so inputs must be 3-channel 32x32 images (32 -> 16 -> 8 after two pools).
    """

    def __init__(self):
        super().__init__()
        # Convolutions; the paddings keep the spatial size unchanged.
        self.conv1 = nn.Conv2d(3, 128, 5, padding=2)
        self.conv2 = nn.Conv2d(128, 128, 5, padding=2)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        # Parameter-free modules; each is applied at more than one place.
        self.pool = nn.MaxPool2d(2, 2)
        # One batch-norm layer per convolution / hidden dense layer.
        self.bn_conv1 = nn.BatchNorm2d(128)
        self.bn_conv2 = nn.BatchNorm2d(128)
        self.bn_conv3 = nn.BatchNorm2d(256)
        self.bn_conv4 = nn.BatchNorm2d(256)
        self.bn_dense1 = nn.BatchNorm1d(1024)
        self.bn_dense2 = nn.BatchNorm1d(512)
        self.dropout_conv = nn.Dropout2d(p=0.25)
        self.dropout = nn.Dropout(p=0.5)
        # Fully connected head.
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 10)

    def conv_layers(self, x):
        """Run the convolutional stack; halves the spatial size twice."""
        out = F.relu(self.bn_conv1(self.conv1(x)))
        out = F.relu(self.bn_conv2(self.conv2(out)))
        out = self.pool(out)
        out = self.dropout_conv(out)
        out = F.relu(self.bn_conv3(self.conv3(out)))
        out = F.relu(self.bn_conv4(self.conv4(out)))
        out = self.pool(out)
        out = self.dropout_conv(out)
        return out

    def dense_layers(self, x):
        """Map flattened features to ten class scores."""
        out = F.relu(self.bn_dense1(self.fc1(x)))
        out = self.dropout(out)
        out = F.relu(self.bn_dense2(self.fc2(out)))
        out = self.dropout(out)
        out = self.fc3(out)
        return out

    def forward(self, x):
        """Return class scores with one row per input sample."""
        out = self.conv_layers(x)
        # Flatten per sample and keep the batch dimension fixed. A reshape with
        # an inferred batch size (-1) would silently merge samples whenever the
        # feature map is not 8x8; this way fc1 raises a shape error instead.
        out = out.flatten(1)
        out = self.dense_layers(out)
        return out

def create_modified3_conv_cifar() -> nn.Module:
    """Return a newly constructed ``Net`` instance."""
    network = Net()
    return network

In [None]:
import torch.nn as nn
import torch.nn.functional as F

def create_modified3_conv_cifar():
    """Build a four-convolution classifier with batch norm and dropout.

    The network class is defined locally and a new instance is returned on
    every call. ``fc1`` takes ``256 * 8 * 8`` features, so inputs must be
    3-channel 32x32 images (32 -> 16 -> 8 after the two 2x2 poolings).

    Returns:
        An ``nn.Module`` mapping a batch of images to ten scores per sample.
    """

    class Net(nn.Module):
        """Two 5x5 convs (128 ch), two 3x3 convs (256 ch), 1024-512-10 head."""

        def __init__(self):
            super().__init__()
            # Convolutions; the paddings keep the spatial size unchanged.
            self.conv1 = nn.Conv2d(3, 128, 5, padding=2)
            self.conv2 = nn.Conv2d(128, 128, 5, padding=2)
            self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
            self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            # One batch-norm layer per convolution / hidden dense layer.
            self.bn_conv1 = nn.BatchNorm2d(128)
            self.bn_conv2 = nn.BatchNorm2d(128)
            self.bn_conv3 = nn.BatchNorm2d(256)
            self.bn_conv4 = nn.BatchNorm2d(256)
            self.bn_dense1 = nn.BatchNorm1d(1024)
            self.bn_dense2 = nn.BatchNorm1d(512)
            self.dropout_conv = nn.Dropout2d(p=0.25)
            self.dropout = nn.Dropout(p=0.5)
            # Fully connected head.
            self.fc1 = nn.Linear(256 * 8 * 8, 1024)
            self.fc2 = nn.Linear(1024, 512)
            self.fc3 = nn.Linear(512, 10)

        def conv_layers(self, x):
            """Run the convolutional stack; halves the spatial size twice."""
            out = F.relu(self.bn_conv1(self.conv1(x)))
            out = F.relu(self.bn_conv2(self.conv2(out)))
            out = self.pool(out)
            out = self.dropout_conv(out)
            out = F.relu(self.bn_conv3(self.conv3(out)))
            out = F.relu(self.bn_conv4(self.conv4(out)))
            out = self.pool(out)
            out = self.dropout_conv(out)
            return out

        def dense_layers(self, x):
            """Map flattened features to ten class scores."""
            out = F.relu(self.bn_dense1(self.fc1(x)))
            out = self.dropout(out)
            out = F.relu(self.bn_dense2(self.fc2(out)))
            out = self.dropout(out)
            out = self.fc3(out)
            return out

        def forward(self, x):
            """Return class scores with one row per input sample."""
            out = self.conv_layers(x)
            # Flatten per sample and keep the batch dimension fixed. A reshape
            # with an inferred batch size (-1) would silently merge samples
            # whenever the feature map is not 8x8; this way fc1 raises a shape
            # error instead.
            out = out.flatten(1)
            out = self.dense_layers(out)
            return out

    return Net()


In [19]:
# T 
from torchvision.datasets import CIFAR10
from torchvision import transforms as T

def get_augmentations(train: bool = True) -> T.Compose:
    """Build the image preprocessing pipeline for training or evaluation.

    Both pipelines end with ``ToTensor`` followed by per-channel ``Normalize``.
    The training pipeline first applies a random resized crop to 32x32, a
    random horizontal flip and a random sharpness adjustment.

    Args:
        train: ``True`` for the augmenting pipeline, ``False`` for the
            deterministic one (tensor conversion and normalisation only).

    Returns:
        A ``T.Compose`` with the selected transforms.
    """
    # Per-channel statistics used for normalisation
    # (presumably CIFAR-10 training-set mean/std; confirm the source).
    means = (0.49139968, 0.48215841, 0.44653091)
    stds = (0.24703223, 0.24348513, 0.26158784)

    # Shared tail of both pipelines.
    to_normalized_tensor = [T.ToTensor(), T.Normalize(mean=means, std=stds)]

    if not train:
        return T.Compose(to_normalized_tensor)

    return T.Compose(
        [
            # `scale` bounds the crop area as a fraction of the original image
            # area, so the upper bound is capped at 1.0: a larger value asks
            # for a crop bigger than the image itself.
            T.RandomResizedCrop(size=32, scale=(0.8, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomAdjustSharpness(sharpness_factor=2),
            *to_normalized_tensor,
        ]
    )


In [26]:

# Train the model on CIFAR-10, report accuracy after every epoch, then save
# the final predictions for the test split.

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): `create_modified3_conv_cifar` is defined more than once in this
# file; this call uses whichever definition was executed last in the kernel.
model = create_modified3_conv_cifar().to(device)
optimizer = Adam(model.parameters())

# NOTE(review): `transform` is not used anywhere below in this cell; the
# datasets take their transforms from get_augmentations() instead.
transform = transforms.Compose([transforms.ToTensor()])
# NOTE(review): no `download` argument is passed, so the dataset files
# presumably must already exist under ./data; confirm.
train_set = datasets.CIFAR10(root='./data', train=True, transform=get_augmentations(train=True))
test_set = datasets.CIFAR10(root='./data', train=False, transform=get_augmentations(train=False))
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
# Not shuffled, so predictions keep the dataset order.
test_loader = DataLoader(test_set, batch_size=1000, shuffle=False)

# range(1, 300) yields epochs 1..299.
for epoch in range(1, 300):
    train(model, device, train_loader, optimizer, epoch)
    accuracy = test(model, device, test_loader)
    print(f'Epoch: {epoch}, Accuracy: {accuracy}')
    # NOTE(review): the stopping criterion is evaluated on the test split, so
    # the reported accuracy is selected on the same data it is measured on.
    # Consider holding out a validation split for this check.
    if accuracy >= 0.95:
        break

# NOTE(review): test_loader uses the train=False pipeline (ToTensor + Normalize
# only) and predict_tta puts the model in eval mode, so the passes being
# averaged look identical. Confirm whether an augmented loader was intended.
predictions = predict_tta(model, test_loader, device)
torch.save(predictions, 'predictions.pt')


Epoch: 1, Accuracy: 0.6272
Epoch: 2, Accuracy: 0.707
Epoch: 3, Accuracy: 0.7514
Epoch: 4, Accuracy: 0.762
Epoch: 5, Accuracy: 0.7956
Epoch: 6, Accuracy: 0.7992
Epoch: 7, Accuracy: 0.8074
Epoch: 8, Accuracy: 0.8268
Epoch: 9, Accuracy: 0.8289
Epoch: 10, Accuracy: 0.8395
Epoch: 11, Accuracy: 0.8513
Epoch: 12, Accuracy: 0.8478
Epoch: 13, Accuracy: 0.8599
Epoch: 14, Accuracy: 0.8624
Epoch: 15, Accuracy: 0.8616
Epoch: 16, Accuracy: 0.8696
Epoch: 17, Accuracy: 0.8664
Epoch: 18, Accuracy: 0.8684
Epoch: 19, Accuracy: 0.8707
Epoch: 20, Accuracy: 0.8765
Epoch: 21, Accuracy: 0.8744
Epoch: 22, Accuracy: 0.8797
Epoch: 23, Accuracy: 0.8779
Epoch: 24, Accuracy: 0.8838
Epoch: 25, Accuracy: 0.8849
Epoch: 26, Accuracy: 0.8828
Epoch: 27, Accuracy: 0.8873
Epoch: 28, Accuracy: 0.8836
Epoch: 29, Accuracy: 0.8888
Epoch: 30, Accuracy: 0.8876
Epoch: 31, Accuracy: 0.8914
Epoch: 32, Accuracy: 0.8872
Epoch: 33, Accuracy: 0.8893
Epoch: 34, Accuracy: 0.8935
Epoch: 35, Accuracy: 0.8898
Epoch: 36, Accuracy: 0.8932
Epo