In [1]:
import numpy as np
import torch
import torchvision
import tqdm
from torch import nn
from torch.nn import functional as F
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from torchvision import transforms, datasets

In [2]:
import random

def set_random_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

set_random_seed(42)

In [3]:
mean, std = 0.5, 0.5

print(mean, std)
data_transform = transforms.Compose([
        transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
        transforms.Resize((64, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean,
                             std=std)
    ])

data_transform_test = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean,
                             std=std)
    ])


full_dataset = datasets.ImageFolder(root='data/prepared', transform=data_transform)

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                    batch_size=256, shuffle=True,
                                    num_workers=2)
val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                    batch_size=512, shuffle=True,
                                    num_workers=2)

0.5 0.5


In [4]:
# Just very simple sanity checks
assert isinstance(train_dataset[0], tuple)
assert len(train_dataset[0]) == 2
assert isinstance(train_dataset[1][1], int)
print("tests passed")

tests passed


In [10]:
from tqdm import tqdm

def train_one_epoch(model, train_dataloader, criterion, optimizer, device="cuda:0"):
    model = model.to(device).train()
    idx = 0
    progress_bar = tqdm(train_dataloader)
    for (images, labels) in (progress_bar):
        images, labels = images.to(device), labels.to(device)
        preds = model(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        progress_bar.set_description("Loss = {:.4f}".format(loss.item()))
        idx += 1

def predict(model, test_dataloder, criterion, device="cuda:0"):
    model.eval()
    model = model.to(device).eval()
    losses = []
    predicted_classes = []
    true_classes = []
    with torch.no_grad():
        for idx, (images, labels) in enumerate((test_dataloder)): 
            images, labels = images.to(device), labels.to(device)
            preds = model(images)
            loss = criterion(preds, labels)
            losses.append(loss)
            predicted_classes.append(preds.argmax(1))
            true_classes.append(labels)
    predicted_classes = torch.cat(predicted_classes).flatten()
    true_classes = torch.cat(true_classes).flatten()
    losses = torch.Tensor(losses)
    return losses, predicted_classes, true_classes


def train(model, train_dataloader, val_dataloader, criterion, optimizer, device="cuda:0", n_epochs=10, scheduler=None):
    model.to(device)
    history_tr, history_val = [], []
    for epoch in range(n_epochs):
        train_one_epoch(model, train_dataloader, criterion, optimizer, device)
#         losses, prd_cls, true_cls = predict(model, train_dataloader_part, criterion, device)
#         train_loss = (prd_cls == true_cls).float().mean().item()
        train_loss = 0
        print(f"Train loss: {train_loss:.4f}")
        losses, prd_cls, true_cls = predict(model, val_dataloader, criterion, device)
        val_loss = (prd_cls == true_cls).float().mean().item()
        print(f"Validation loss: {val_loss:.4f}")
        history_tr.append(train_loss)
        history_val.append(val_loss)

        if scheduler is not None:
            scheduler.step()
    return history_tr, history_val

In [11]:
class AlexResNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv1 = AlexResNet.block(in_channels, 64)
        self.conv2 = AlexResNet.block_with_pooling(64, 128)
        self.layer1 = AlexResNet.residual_block(128)
        self.conv3 = AlexResNet.block_with_pooling(128, 256)
        self.conv4 = AlexResNet.block_with_pooling(256, 256)
        self.layer2 = AlexResNet.residual_block(256)
        self.conv5 = AlexResNet.block_with_pooling(256, 512)
        self.conv6 = AlexResNet.block_with_pooling(512, 512)
        self.layer3 = AlexResNet.residual_block(512)
        self.clf =  AlexResNet.classifier(512, num_classes)
        
    def forward(self, inp):
        out = self.conv2(self.conv1(inp))
        out = self.layer1(out) + out
        out = self.conv4(self.conv3(out))
        out = self.layer2(out) + out
        out = self.conv6(self.conv5(out))
        out = self.layer3(out) + out
        out = self.clf(out)
        return out

    @staticmethod
    def block(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    @staticmethod
    def block_with_pooling(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
    
    @staticmethod
    def residual_block(channels):
        return nn.Sequential(
            AlexResNet.block(channels, channels),
            AlexResNet.block(channels, channels)
        )
        
    @staticmethod
    def classifier(in_channels, n_classes):
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_channels, n_classes)
        )

In [12]:
n_classes = 10
model = AlexResNet(3, n_classes)
n_epochs = 20
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
predict(model, val_dataloader, criterion, device=device)

KeyboardInterrupt: 

In [None]:
all_losses, predicted_labels, true_labels = predict(model, val_dataloader, criterion, device)
assert len(predicted_labels) == len(val_dataset)
accuracy = accuracy_score(predicted_labels.cpu(), true_labels.cpu())
print("tests passed")

In [None]:
history_tr, history_val = train(
    model,
    train_dataloader,
    val_dataloader,
    criterion,
    optimizer,
    device,
    n_epochs,
    scheduler
)