In [None]:
import os
import sys
import json
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from sklearn.metrics import ConfusionMatrixDisplay
from tqdm import tqdm

# Add the parent directory to the module search path.
sys.path.append("..")

# Dataset root and checkpoint root, both relative to the notebook's working directory.
DATA_PATH = os.path.join("..", "data")
CHECKPOINTS_PATH =  os.path.join("..", "checkpoints")
# NOTE(review): no function visible in this notebook reads this global (train() derives
# its own device through get_device()) -- confirm whether it is still needed.
device = "cuda"

### Custom models

1. AlexNet


In [34]:
class AlexNet_32x32(nn.Module):
    """AlexNet-style classifier with small kernels and strides.

    The network is four ``Conv2d -> ReLU -> MaxPool2d`` stages, an adaptive
    average pool to a 4x4 map, and three dropout + linear layers.

    Args:
        num_classes: Size of the output (logit) vector.
        dropout: Drop probability used in front of every linear layer.

    NOTE(review): ``forward`` applies no non-linearity between the linear
    layers -- confirm this is intended.
    """

    def __init__(self, num_classes: int = 10, dropout: float = 0.5) -> None:
        super().__init__()
        # Feature extractor. Attribute names and creation order are kept stable
        # because they define the state_dict keys and the seeded initialisation.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=1)

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=1)

        self.conv3 = nn.Conv2d(in_channels=192, out_channels=256, kernel_size=3, stride=1)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2)

        self.conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2)

        # One activation module shared by all convolution stages.
        self.relu = nn.ReLU(inplace=True)

        # Adaptive average pooling to a fixed 4x4 spatial map (not a 1x1 global pool).
        pooled_side = 4
        self.global_avg_pool = nn.AdaptiveAvgPool2d((pooled_side, pooled_side))

        # Classification head.
        flat_features = 256 * pooled_side * pooled_side
        hidden_features = 4096

        self.dropout1 = nn.Dropout(dropout)
        self.lin1 = nn.Linear(flat_features, hidden_features)

        self.dropout2 = nn.Dropout(dropout)
        self.lin2 = nn.Linear(hidden_features, hidden_features)

        self.dropout3 = nn.Dropout(dropout)
        self.lin3 = nn.Linear(hidden_features, num_classes)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for a batch of images ``x``."""
        feature_stages = (
            (self.conv1, self.maxpool1),
            (self.conv2, self.maxpool2),
            (self.conv3, self.maxpool3),
            (self.conv4, self.maxpool4),
        )
        for conv, pool in feature_stages:
            x = pool(self.relu(conv(x)))

        # Pool to 4x4 and flatten everything except the batch dimension.
        x = self.global_avg_pool(x).flatten(1)

        head_stages = (
            (self.dropout1, self.lin1),
            (self.dropout2, self.lin2),
            (self.dropout3, self.lin3),
        )
        for drop, linear in head_stages:
            x = linear(drop(x))

        return x

2. ResNet

In [35]:
class ResidualBlock(nn.Module):
    """Two-convolution residual block: ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution. The second convolution always
            uses stride 1, so the block reduces the spatial size exactly once
            and stays shape-compatible with a shortcut that strides once.
        downsampler: Optional module applied to the input to build the
            shortcut; when ``None`` the input itself is used, which requires
            ``in_channels == out_channels`` and ``stride == 1``.
    """

    def __init__(self, in_channels, out_channels, stride = 1, downsampler = None):
        super().__init__()
        self.downsampler = downsampler

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Fixed: this used ``stride=stride``, which downsampled a second time and made
        # ``out + identity`` fail for any stride != 1. For stride == 1 nothing changes.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)


    def forward(self, x):
        """Apply the block to ``x`` and return the activated residual sum."""
        identity = x
        if self.downsampler is not None:
            # Project the shortcut so its shape matches the main path.
            identity = self.downsampler(identity)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        out = out + identity
        out = self.relu(out)
        return out


class ResNet_32x32(nn.Module):
    """Small ResNet: a stem convolution followed by three (max-pool, residual layer) stages.

    Args:
        num_classes: Size of the output (logit) vector.
        block: Residual block class; it is called as
            ``block(in_channels, out_channels, stride, downsampler)`` for the
            first block of a layer and ``block(channels, channels, 1)`` for the rest.
    """

    def __init__(self, num_classes: int = 10, block: type[nn.Module] = ResidualBlock):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, (3, 3), stride=1, padding=1)
        self.bnorm = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        self.max_pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.res_layer1 = self._residual_layer(block, 64, 128, 2)

        self.max_pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.res_layer2 = self._residual_layer(block, 128, 256, 2)

        self.max_pool3 = nn.MaxPool2d(kernel_size=3, stride=1)
        self.res_layer3 = self._residual_layer(block, 256, 512, 2)

        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

        self.lin = nn.Linear(512, num_classes)


    def _residual_layer(self, block, in_channels, out_channels, blocks_num, stride=1):
        """Create a residual layer made of ``blocks_num`` residual blocks.

        Only the first block changes the channel count and applies ``stride``;
        it gets a 1x1 convolution + batch-norm shortcut whenever the input and
        output shapes differ. The remaining blocks keep the shape.
        """
        downsampler = None
        if stride != 1 or in_channels != out_channels:
            downsampler = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

        layers = [block(in_channels, out_channels, stride, downsampler)]
        for _ in range(blocks_num - 1):
            # Fixed: follow-up blocks have no shortcut projection, so they must not
            # stride (they previously received ``stride`` and broke for stride != 1).
            layers.append(block(out_channels, out_channels, 1))

        return nn.Sequential(*layers)


    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        x = self.relu(self.bnorm(self.conv(x)))

        x = self.res_layer1(self.max_pool1(x))
        x = self.res_layer2(self.max_pool2(x))
        x = self.res_layer3(self.max_pool3(x))

        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.lin(x)
        return x

### Data prep

In [2]:
# Three-element mean / standard deviation lists passed to transforms.Normalize
# in get_dataset().
# NOTE(review): presumably the per-channel statistics of the CINIC training set -- confirm.
CINIC_MEAN = [0.47889522, 0.47227842, 0.43047404]
CINIC_STD = [0.24205776, 0.23828046, 0.25874835]


def get_dataset(
        path: str,
        batch_size: int,
        shuffle: bool,
        use_augmentations: bool,
) -> DataLoader:
    """Wrap the image folder at ``path`` in a ``DataLoader``.

    Every image is converted to a tensor and normalised with
    ``CINIC_MEAN`` / ``CINIC_STD``. ``use_augmentations`` selects an extra
    list of transforms placed in front of that, which is currently empty.

    Args:
        path: Root directory handed to ``torchvision.datasets.ImageFolder``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles the samples.
        use_augmentations: Whether to prepend the augmentation transforms.

    Returns:
        A loader over the folder that uses three worker processes.
    """
    pipeline = []
    if use_augmentations:
        # TODO: add augmentations
        pipeline.extend([])
    pipeline.append(transforms.ToTensor())
    pipeline.append(transforms.Normalize(mean=CINIC_MEAN, std=CINIC_STD))

    image_folder = torchvision.datasets.ImageFolder(path, transform=transforms.Compose(pipeline))
    return DataLoader(image_folder, batch_size=batch_size, shuffle=shuffle, num_workers=3)


def get_cinic(
        data_path: str,
        batch_size: int = 256
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train, validation and test loaders found under ``data_path``.

    The splits are read from the ``train``, ``valid`` and ``test``
    sub-directories. Only the training loader is shuffled and requests
    augmentations.

    Returns:
        ``(train_loader, validation_loader, test_loader)``.
    """
    # (sub-directory, shuffle, use_augmentations) for each split, in return order.
    split_specs = (
        ("train", True, True),
        ("valid", False, False),
        ("test", False, False),
    )
    loaders = [
        get_dataset(os.path.join(data_path, folder), batch_size, shuffled, augmented)
        for folder, shuffled, augmented in split_specs
    ]
    return loaders[0], loaders[1], loaders[2]

### Utils

In [3]:
import numpy as np
import torch
import random


def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch global random number generators.

    Args:
        seed: Value used for all three generators.
    """
    random.seed(seed)
    # Fixed: ``np.random.RandomState(seed)`` only built a throw-away generator
    # and left NumPy's global RNG unseeded.
    np.random.seed(seed)
    torch.manual_seed(seed)


def get_device():
    """Return ``"cuda:0"`` when PyTorch reports at least one CUDA device, otherwise ``'cpu'``."""
    if torch.cuda.device_count() > 0:
        return "cuda:0"
    return 'cpu'

### Train

In [4]:
# Lookup tables that train() indexes with config["model"] and config["optimizer"].
# Model name -> model class.
MODELS = {
    "AlexNet": AlexNet_32x32,
    "ResNet": ResNet_32x32
}

# Optimizer name -> optimizer class.
OPTIMIZERS = {
    "SGD": torch.optim.SGD,
    "AdamW": torch.optim.AdamW
}


def train(config: dict, data_path: str):
    """Train the model described by ``config`` and write all artefacts to its checkpoint folder.

    Keys read from ``config``:
        seed (optional, default 0), checkpoint_folder, batch_size (optional,
        default 256), model and model_params, optimizer and optimizer_params,
        epochs, warmup_epochs, early_stopping (optional dict with "patience").

    Files written to ``config["checkpoint_folder"]``:
        config.json, state.pth (lowest validation loss seen at or after the
        warm-up epochs), test_metrics.txt, confusion_matrix_test.jpg,
        confusion_matrix_test_no_diag.jpg and metrics.csv.

    Args:
        config: Run configuration, see above.
        data_path: Dataset root passed to ``get_cinic``.
    """
    # Set seed for reproducibility
    seed = int(config.get("seed", 0))
    set_seed(seed)

    device = get_device()
    print("Device: ", device)

    checkpoint = config["checkpoint_folder"]
    os.makedirs(checkpoint, exist_ok=True)
    state_path = os.path.join(checkpoint, "state.pth")

    with open(os.path.join(checkpoint, "config.json"), "w") as f:
        json.dump(config, f)

    batch_size = int(config.get("batch_size", 256))
    cinic_train, cinic_valid, cinic_test = get_cinic(data_path, batch_size=batch_size)
    model: nn.Module = MODELS[config["model"]](**config["model_params"]).to(device)
    optimizer: torch.optim.Optimizer = OPTIMIZERS[config["optimizer"]](params=model.parameters(),
                                                **config["optimizer_params"])
    criterion = nn.CrossEntropyLoss()

    epochs = int(config["epochs"])
    # The first entry of every list is a placeholder row that precedes epoch 0;
    # it is kept so that metrics.csv keeps the same layout.
    metrics = {"loss": [float("+inf")], "accuracy": [0.0], "val_loss": [float("+inf")], "val_accuracy": [0.0]}
    warmup_epochs = int(config["warmup_epochs"])

    # Lowest validation loss that has been checkpointed during this run.
    best_checkpoint_loss = float("inf")
    checkpoint_saved = False

    # Early-stopping bookkeeping.
    best_loss = float('inf')
    best_loss_epoch = -1

    classes = cinic_test.dataset.classes
    classes = sorted(classes, key=lambda x: cinic_test.dataset.class_to_idx[x])
    # Explicit label ids keep the confusion matrices aligned with ``classes``
    # even when a class is missing from the targets/predictions being plotted.
    class_ids = list(range(len(classes)))

    # Training
    for epoch in range(epochs):
        print("Epoch", epoch)

        print("Processing training")
        model.train()
        accuracy, loss = train_epoch(model, cinic_train, optimizer, criterion)
        metrics["accuracy"].append(accuracy)
        metrics["loss"].append(loss)

        print("Processing validation")
        model.eval()
        val_accuracy, val_loss = valid_epoch(model, cinic_valid, criterion)
        metrics["val_accuracy"].append(val_accuracy)
        metrics["val_loss"].append(val_loss)

        print(f"Loss: {metrics['loss'][-1]:.4f}", end=' ')
        print(f"Accuracy: {metrics['accuracy'][-1]:.4f}", end=' ')
        print(f"Validation Loss: {metrics['val_loss'][-1]:.4f}", end=' ')
        print(f"Validation Accuracy: {metrics['val_accuracy'][-1]:.4f}")

        # Saving checkpoint
        # Fixed: the comparison used the previous epoch's loss, so a worse epoch
        # could overwrite a better checkpoint. Compare with the best saved loss.
        if epoch >= warmup_epochs and val_loss < best_checkpoint_loss:
            best_checkpoint_loss = val_loss
            torch.save({
                    "model_state": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "loss": criterion.state_dict(),
                    "epoch": epoch
            }, state_path)
            checkpoint_saved = True

        # Early stopping
        if "early_stopping" in config:
            if val_loss < best_loss:
                best_loss_epoch = epoch
                best_loss = val_loss
            elif epoch - best_loss_epoch >= config["early_stopping"]["patience"]:
                print("Early stopping!")
                break

    # Fixed: ``best_model = model`` was only an alias of the live model, so the test
    # metrics always described the last epoch. Restore the checkpointed weights;
    # when no checkpoint was written in this run the final weights are evaluated.
    if checkpoint_saved:
        best_state = torch.load(state_path, map_location=device)
        model.load_state_dict(best_state["model_state"])
    # Ensure inference mode even when the training loop ran zero epochs.
    model.eval()

    # Calculate test metrics and confusion matrix
    test_accuracy, test_loss, test_targets, test_predictions = test_epoch(model,
                                                                          cinic_test, criterion)
    with open(os.path.join(checkpoint, "test_metrics.txt"), "w") as f:
        f.write(f"Test Loss: {test_loss:.4f} Test Accuracy: {test_accuracy:.4f}")

    # Confusion matrix
    disp = ConfusionMatrixDisplay.from_predictions(test_targets, test_predictions,
                                                   labels=class_ids,
                                                   display_labels=classes)
    disp.figure_.savefig(os.path.join(checkpoint, "confusion_matrix_test.jpg"))

    # Confusion matrix without diagonal
    wrong_predictions_idx = test_targets != test_predictions
    test_targets = test_targets[wrong_predictions_idx]
    test_predictions = test_predictions[wrong_predictions_idx]
    disp = ConfusionMatrixDisplay.from_predictions(test_targets, test_predictions,
                                                   labels=class_ids,
                                                   display_labels=classes)
    disp.figure_.savefig(os.path.join(checkpoint, "confusion_matrix_test_no_diag.jpg"))

    df = pd.DataFrame(metrics)
    df.to_csv(os.path.join(checkpoint, "metrics.csv"))


def train_epoch(
    model: nn.Module,
    train_ds: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.CrossEntropyLoss
):
    """Run one optimisation pass over ``train_ds``.

    The caller is responsible for putting ``model`` into training mode.

    Args:
        model: Network to optimise; batches are moved to the device of its parameters.
        train_ds: Loader yielding ``(input, target)`` batches with integer class targets.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called with the logits and one-hot float targets.

    Returns:
        ``(accuracy, loss)``: per-batch means averaged with the batch sizes as weights.
    """
    # Follow the model's device instead of hard-coding CUDA, so the function also
    # works when the model lives on the CPU. A parameter-free model leaves the
    # batches where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    losses = []
    accuracies = []
    batch_sizes = []
    for input, target in tqdm(train_ds):
        optimizer.zero_grad()

        if device is not None:
            input, target = input.to(device), target.to(device)
        output = model(input)
        # Take the class count from the logits rather than assuming 10 classes.
        ohe_target = F.one_hot(target, output.shape[-1]).type(torch.float32)
        loss = criterion(output, ohe_target)

        loss.backward()
        optimizer.step()

        pred = torch.argmax(output, dim=-1)
        accurate_pred = (pred == target).type(torch.float32)
        accuracies.append(torch.mean(accurate_pred).item())

        losses.append(loss.item())
        batch_sizes.append(len(input))

    accuracy = np.average(accuracies, weights=batch_sizes)
    loss = np.average(losses, weights=batch_sizes)
    return accuracy, loss


def valid_epoch(
    model: nn.Module,
    valid_ds: DataLoader,
    criterion: nn.CrossEntropyLoss
):
    """Evaluate ``model`` on ``valid_ds`` without tracking gradients.

    The caller is responsible for putting ``model`` into evaluation mode.

    Args:
        model: Network to evaluate; batches are moved to the device of its parameters.
        valid_ds: Loader yielding ``(input, target)`` batches with integer class targets.
        criterion: Loss called with the logits and one-hot float targets.

    Returns:
        ``(accuracy, loss)``: per-batch means averaged with the batch sizes as weights.
    """
    # Follow the model's device instead of hard-coding CUDA; a parameter-free
    # model leaves the batches where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    val_losses = []
    val_accuracies = []
    batch_sizes = []
    with torch.no_grad():
        for input, target in tqdm(valid_ds):
            if device is not None:
                input, target = input.to(device), target.to(device)
            output = model(input)
            # Take the class count from the logits rather than assuming 10 classes.
            ohe_target = F.one_hot(target, output.shape[-1]).type(torch.float32)
            val_loss = criterion(output, ohe_target)

            # Calculate accuracy
            pred = torch.argmax(output, dim=-1)
            accurate_pred = (pred == target).type(torch.float32)
            val_accuracies.append(torch.mean(accurate_pred).item())

            val_losses.append(val_loss.item())
            # Remember the batch size so the averages can be weighted by it
            batch_sizes.append(len(input))

    val_accuracy = np.average(val_accuracies, weights=batch_sizes)
    val_loss = np.average(val_losses, weights=batch_sizes)
    return val_accuracy, val_loss


def test_epoch(
    model: nn.Module,
    valid_ds: DataLoader,
    criterion: nn.CrossEntropyLoss
):
    val_losses = []
    val_accuracies = []
    batch_sizes = []
    predictions = []
    targets = []

    with torch.no_grad():
        for input, target in tqdm(valid_ds):
            input, target = input.cuda(), target.cuda()
            ohe_target = F.one_hot(target, 10).type(torch.float32)
            output = model(input)
            val_loss = criterion(output, ohe_target)

            # Calculate accuracy
            pred = torch.argmax(output, dim=-1)
            accurate_pred = (pred == target).type(torch.float32)
            val_accuracies.append(torch.mean(accurate_pred).item())

            val_losses.append(val_loss.item())
            # Store information regarding
            batch_sizes.append(len(input))

            predictions.append(pred)
            targets.append(target)

    val_accuracy = np.average(val_accuracies, weights=batch_sizes)
    val_loss = np.average(val_losses, weights=batch_sizes)
    predictions = torch.concatenate(predictions).cpu().numpy()
    targets = torch.concatenate(targets).cpu().numpy()
    return val_accuracy, val_loss, targets, predictions

In [5]:
# Build the three data loaders and read the saved training state of one run.
cinic_train, cinic_val, cinic_test = get_cinic(DATA_PATH, 256)
checkpoint_path = os.path.join(CHECKPOINTS_PATH, "ResNet_lr_0_001")
# weights_only=False lets torch.load unpickle arbitrary Python objects, so only
# load checkpoint files from a trusted source this way.
state = torch.load(os.path.join(checkpoint_path, "state.pth"), weights_only=False)

In [6]:
# Show which entries the loaded checkpoint dictionary contains.
state.keys()

dict_keys(['model_state', 'optimizer', 'loss', 'epoch'])

In [7]:
# Rebuild the network with its default arguments and restore the checkpointed weights.
model = ResNet_32x32()
model.load_state_dict(state["model_state"])

<All keys matched successfully>

In [31]:
def evaluate(model: nn.Module, loader: DataLoader,):
    """Compute the classification accuracy of ``model`` over ``loader``.

    The model is moved to the GPU when CUDA is available and stays on the CPU
    otherwise. Its train/eval mode is not changed; call ``model.eval()`` first.

    Args:
        model: Network producing one logit vector per sample.
        loader: Loader yielding ``(input, target)`` batches with integer class targets.

    Returns:
        ``(accuracies, overall)``: a list with one single-element accuracy tensor
        per batch, and the overall accuracy as a scalar tensor. The overall
        value weights every batch by its size, so it equals the fraction of
        correctly classified samples.
    """
    # Fall back to the CPU instead of failing when no GPU is present.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    accuracies = []
    batch_sizes = []
    with torch.no_grad():
        for input, target in tqdm(loader):
            input, target = input.to(device), target.to(device)
            output = model(input)
            predictions = torch.argmax(output, dim=-1)
            accuracies.append((torch.sum(predictions == target) / len(target)).reshape((-1)))
            batch_sizes.append(len(target))

    # Fixed: a plain mean of the batch accuracies over-weights a smaller last
    # batch; ``batch_sizes`` was collected but never used.
    per_batch = torch.cat(accuracies)
    weights = torch.tensor(batch_sizes, dtype=per_batch.dtype, device=per_batch.device)
    overall = torch.sum(per_batch * weights) / torch.sum(weights)
    return accuracies, overall

In [14]:
model.eval()
# evaluate() returns (per-batch accuracies, overall accuracy). Unpack the pair so
# that `accuracies` holds only the list of per-batch tensors.
accuracies, acc_test = evaluate(model, cinic_test)

  1%|          | 4/352 [00:08<09:44,  1.68s/it]

Accuracy:  0.8359375
Accuracy:  0.828125
Accuracy:  0.87109375
Accuracy:  0.84375
Accuracy:  0.8359375
Accuracy:  0.84375


  3%|▎         | 11/352 [00:09<02:20,  2.43it/s]

Accuracy:  0.78515625
Accuracy:  0.859375
Accuracy:  0.9296875
Accuracy:  0.89453125
Accuracy:  0.90625
Accuracy:  0.8125
Accuracy:  0.7578125


  5%|▌         | 19/352 [00:09<00:56,  5.89it/s]

Accuracy:  0.6875
Accuracy:  0.7109375
Accuracy:  0.734375
Accuracy:  0.83984375
Accuracy:  0.7890625
Accuracy:  0.71875
Accuracy:  0.62109375


  7%|▋         | 26/352 [00:09<00:31, 10.20it/s]

Accuracy:  0.86328125
Accuracy:  0.86328125
Accuracy:  0.91015625
Accuracy:  0.93359375
Accuracy:  0.859375
Accuracy:  0.83984375


  8%|▊         | 29/352 [00:09<00:27, 11.92it/s]

Accuracy:  0.921875
Accuracy:  0.8125
Accuracy:  0.80859375
Accuracy:  0.75
Accuracy:  0.71875
Accuracy:  0.7890625


 10%|▉         | 35/352 [00:09<00:19, 16.02it/s]

Accuracy:  0.8359375
Accuracy:  0.87109375
Accuracy:  0.96484375
Accuracy:  0.94921875
Accuracy:  0.921875


 12%|█▏        | 42/352 [00:10<00:15, 19.48it/s]

Accuracy:  0.9296875
Accuracy:  0.9296875
Accuracy:  0.9453125
Accuracy:  0.9609375
Accuracy:  0.921875
Accuracy:  0.9140625


 14%|█▎        | 48/352 [00:10<00:13, 22.00it/s]

Accuracy:  0.51171875
Accuracy:  0.82421875
Accuracy:  0.76171875
Accuracy:  0.75390625
Accuracy:  0.8671875
Accuracy:  0.8984375


 15%|█▌        | 54/352 [00:10<00:12, 23.28it/s]

Accuracy:  0.9296875
Accuracy:  0.87109375
Accuracy:  0.8359375
Accuracy:  0.6640625
Accuracy:  0.91796875


 16%|█▌        | 57/352 [00:10<00:12, 23.81it/s]

Accuracy:  0.75390625
Accuracy:  0.5078125
Accuracy:  0.69140625
Accuracy:  0.76953125
Accuracy:  0.765625


 18%|█▊        | 63/352 [00:11<00:11, 24.18it/s]

Accuracy:  0.6796875
Accuracy:  0.82421875
Accuracy:  0.71484375
Accuracy:  0.921875
Accuracy:  0.90234375
Accuracy:  0.875


 20%|█▉        | 69/352 [00:11<00:13, 21.76it/s]

Accuracy:  0.91015625
Accuracy:  0.87109375
Accuracy:  0.6796875
Accuracy:  0.80078125
Accuracy:  0.83203125


 21%|██▏       | 75/352 [00:11<00:12, 22.67it/s]

Accuracy:  0.75
Accuracy:  0.74609375
Accuracy:  0.74609375
Accuracy:  0.74609375
Accuracy:  0.7265625


 22%|██▏       | 78/352 [00:11<00:11, 23.74it/s]

Accuracy:  0.75390625
Accuracy:  0.7578125
Accuracy:  0.73828125
Accuracy:  0.609375
Accuracy:  0.73828125


 24%|██▍       | 84/352 [00:12<00:11, 23.32it/s]

Accuracy:  0.75
Accuracy:  0.78515625
Accuracy:  0.73828125
Accuracy:  0.8046875
Accuracy:  0.8203125
Accuracy:  0.73046875


 26%|██▌       | 90/352 [00:12<00:11, 22.20it/s]

Accuracy:  0.78125
Accuracy:  0.7734375
Accuracy:  0.65234375
Accuracy:  0.49609375
Accuracy:  0.39453125


 27%|██▋       | 96/352 [00:12<00:11, 23.01it/s]

Accuracy:  0.609375
Accuracy:  0.671875
Accuracy:  0.65234375
Accuracy:  0.703125
Accuracy:  0.6484375


 28%|██▊       | 99/352 [00:12<00:11, 22.79it/s]

Accuracy:  0.82421875
Accuracy:  0.69140625
Accuracy:  0.7578125
Accuracy:  0.70703125
Accuracy:  0.73046875


 30%|██▉       | 105/352 [00:12<00:10, 23.92it/s]

Accuracy:  0.796875
Accuracy:  0.71484375
Accuracy:  0.77734375
Accuracy:  0.65625
Accuracy:  0.68359375
Accuracy:  0.67578125


 32%|███▏      | 111/352 [00:13<00:09, 24.92it/s]

Accuracy:  0.6875
Accuracy:  0.69921875
Accuracy:  0.71875
Accuracy:  0.73828125
Accuracy:  0.69921875
Accuracy:  0.7265625


 33%|███▎      | 117/352 [00:13<00:09, 24.84it/s]

Accuracy:  0.5703125
Accuracy:  0.6171875
Accuracy:  0.6171875
Accuracy:  0.578125
Accuracy:  0.6015625
Accuracy:  0.61328125


 35%|███▍      | 123/352 [00:13<00:08, 26.12it/s]

Accuracy:  0.6796875
Accuracy:  0.72265625
Accuracy:  0.58984375
Accuracy:  0.5625
Accuracy:  0.6015625
Accuracy:  0.49609375


 37%|███▋      | 129/352 [00:13<00:09, 24.52it/s]

Accuracy:  0.67578125
Accuracy:  0.53125
Accuracy:  0.60546875
Accuracy:  0.58984375
Accuracy:  0.55859375
Accuracy:  0.61328125


 38%|███▊      | 135/352 [00:14<00:08, 25.53it/s]

Accuracy:  0.62109375
Accuracy:  0.5859375
Accuracy:  0.5546875
Accuracy:  0.56640625
Accuracy:  0.671875
Accuracy:  0.578125


 40%|████      | 141/352 [00:14<00:08, 25.30it/s]

Accuracy:  0.6484375
Accuracy:  0.60546875
Accuracy:  0.58203125
Accuracy:  0.546875
Accuracy:  0.7578125
Accuracy:  0.78515625


 43%|████▎     | 150/352 [00:14<00:07, 27.35it/s]

Accuracy:  0.82421875
Accuracy:  0.7890625
Accuracy:  0.75390625
Accuracy:  0.76953125
Accuracy:  0.78515625
Accuracy:  0.58984375
Accuracy:  0.47265625


 43%|████▎     | 153/352 [00:14<00:08, 24.85it/s]

Accuracy:  0.4453125
Accuracy:  0.5859375
Accuracy:  0.640625
Accuracy:  0.40625
Accuracy:  0.41015625


 45%|████▌     | 159/352 [00:15<00:07, 24.70it/s]

Accuracy:  0.23828125
Accuracy:  0.4765625
Accuracy:  0.64453125
Accuracy:  0.71875
Accuracy:  0.65234375
Accuracy:  0.6015625


 47%|████▋     | 165/352 [00:15<00:07, 25.22it/s]

Accuracy:  0.70703125
Accuracy:  0.8203125
Accuracy:  0.78125
Accuracy:  0.69921875
Accuracy:  0.67578125
Accuracy:  0.68359375


 49%|████▊     | 171/352 [00:15<00:07, 23.50it/s]

Accuracy:  0.7578125
Accuracy:  0.7734375
Accuracy:  0.765625
Accuracy:  0.73828125
Accuracy:  0.5703125
Accuracy:  0.48046875


 50%|█████     | 177/352 [00:15<00:07, 24.32it/s]

Accuracy:  0.4921875
Accuracy:  0.359375
Accuracy:  0.515625
Accuracy:  0.796875
Accuracy:  0.8046875
Accuracy:  0.78125


 52%|█████▏    | 183/352 [00:16<00:06, 26.05it/s]

Accuracy:  0.77734375
Accuracy:  0.78515625
Accuracy:  0.77734375
Accuracy:  0.796875
Accuracy:  0.6640625
Accuracy:  0.59765625


 54%|█████▎    | 189/352 [00:16<00:06, 25.85it/s]

Accuracy:  0.5546875
Accuracy:  0.56640625
Accuracy:  0.65625
Accuracy:  0.77734375
Accuracy:  0.5
Accuracy:  0.5859375


 55%|█████▌    | 195/352 [00:16<00:06, 24.76it/s]

Accuracy:  0.5390625
Accuracy:  0.52734375
Accuracy:  0.68359375
Accuracy:  0.68359375
Accuracy:  0.53125


 57%|█████▋    | 201/352 [00:16<00:05, 26.33it/s]

Accuracy:  0.515625
Accuracy:  0.56640625
Accuracy:  0.6484375
Accuracy:  0.625
Accuracy:  0.671875
Accuracy:  0.609375


 59%|█████▉    | 207/352 [00:16<00:05, 26.34it/s]

Accuracy:  0.63671875
Accuracy:  0.625
Accuracy:  0.25
Accuracy:  0.203125
Accuracy:  0.30078125
Accuracy:  0.32421875


 61%|██████    | 214/352 [00:17<00:05, 26.94it/s]

Accuracy:  0.2578125
Accuracy:  0.2734375
Accuracy:  0.29296875
Accuracy:  0.921875
Accuracy:  0.93359375
Accuracy:  0.90625


 62%|██████▏   | 217/352 [00:17<00:05, 26.23it/s]

Accuracy:  0.92578125
Accuracy:  0.8984375
Accuracy:  0.9375
Accuracy:  0.9140625
Accuracy:  0.88671875


 63%|██████▎   | 223/352 [00:17<00:04, 26.12it/s]

Accuracy:  0.75390625
Accuracy:  0.81640625
Accuracy:  0.89453125
Accuracy:  0.89453125
Accuracy:  0.86328125
Accuracy:  0.91015625


 65%|██████▌   | 229/352 [00:17<00:04, 25.57it/s]

Accuracy:  0.890625
Accuracy:  0.81640625
Accuracy:  0.796875
Accuracy:  0.8828125
Accuracy:  0.87109375
Accuracy:  0.78515625


 67%|██████▋   | 235/352 [00:18<00:04, 25.83it/s]

Accuracy:  0.890625
Accuracy:  0.9296875
Accuracy:  0.91015625
Accuracy:  0.859375
Accuracy:  0.9296875
Accuracy:  0.90625


 68%|██████▊   | 241/352 [00:18<00:04, 24.28it/s]

Accuracy:  0.90234375
Accuracy:  0.87890625
Accuracy:  0.97265625
Accuracy:  0.8125
Accuracy:  0.8515625


 70%|███████   | 247/352 [00:18<00:04, 25.47it/s]

Accuracy:  0.875
Accuracy:  0.859375
Accuracy:  0.890625
Accuracy:  0.609375
Accuracy:  0.8203125
Accuracy:  0.85546875
Accuracy:  0.81640625


 72%|███████▏  | 253/352 [00:18<00:03, 25.81it/s]

Accuracy:  0.83203125
Accuracy:  0.79296875
Accuracy:  0.8359375
Accuracy:  0.8046875
Accuracy:  0.80859375
Accuracy:  0.7109375


 74%|███████▎  | 259/352 [00:18<00:03, 25.93it/s]

Accuracy:  0.76171875
Accuracy:  0.8046875
Accuracy:  0.72265625
Accuracy:  0.78515625
Accuracy:  0.7421875
Accuracy:  0.8046875


 75%|███████▌  | 265/352 [00:19<00:03, 23.98it/s]

Accuracy:  0.8359375
Accuracy:  0.78515625
Accuracy:  0.734375
Accuracy:  0.74609375
Accuracy:  0.73046875


 77%|███████▋  | 271/352 [00:19<00:03, 24.18it/s]

Accuracy:  0.6328125
Accuracy:  0.79296875
Accuracy:  0.75
Accuracy:  0.6171875
Accuracy:  0.77734375
Accuracy:  0.796875


 79%|███████▊  | 277/352 [00:19<00:02, 25.03it/s]

Accuracy:  0.8125
Accuracy:  0.80078125
Accuracy:  0.703125
Accuracy:  0.77734375
Accuracy:  0.71484375
Accuracy:  0.7578125


 80%|████████  | 283/352 [00:19<00:02, 24.47it/s]

Accuracy:  0.8125
Accuracy:  0.87109375
Accuracy:  0.78515625
Accuracy:  0.88671875
Accuracy:  0.87890625
Accuracy:  0.90234375


 82%|████████▏ | 289/352 [00:20<00:02, 23.82it/s]

Accuracy:  0.9296875
Accuracy:  0.890625
Accuracy:  0.8828125
Accuracy:  0.91015625
Accuracy:  0.88671875


 83%|████████▎ | 292/352 [00:20<00:02, 23.14it/s]

Accuracy:  0.6484375
Accuracy:  0.75390625
Accuracy:  0.7734375
Accuracy:  0.71875


 85%|████████▍ | 298/352 [00:20<00:02, 23.31it/s]

Accuracy:  0.6015625
Accuracy:  0.76171875
Accuracy:  0.6796875
Accuracy:  0.83984375
Accuracy:  0.7890625
Accuracy:  0.7890625


 86%|████████▋ | 304/352 [00:20<00:02, 22.95it/s]

Accuracy:  0.75
Accuracy:  0.7265625
Accuracy:  0.80078125
Accuracy:  0.76171875
Accuracy:  0.70703125


 87%|████████▋ | 307/352 [00:21<00:01, 23.94it/s]

Accuracy:  0.82421875
Accuracy:  0.76953125
Accuracy:  0.7578125
Accuracy:  0.828125
Accuracy:  0.78125


 89%|████████▉ | 313/352 [00:21<00:01, 23.66it/s]

Accuracy:  0.7109375
Accuracy:  0.8125
Accuracy:  0.7578125
Accuracy:  0.75
Accuracy:  0.78515625


 91%|█████████ | 319/352 [00:21<00:01, 23.53it/s]

Accuracy:  0.81640625
Accuracy:  0.7890625
Accuracy:  0.8828125
Accuracy:  0.87890625
Accuracy:  0.9140625
Accuracy:  0.8515625


 92%|█████████▏| 325/352 [00:21<00:01, 24.81it/s]

Accuracy:  0.87109375
Accuracy:  0.90625
Accuracy:  0.88671875
Accuracy:  0.92578125
Accuracy:  0.75
Accuracy:  0.671875


 94%|█████████▍| 331/352 [00:22<00:00, 23.93it/s]

Accuracy:  0.640625
Accuracy:  0.7265625
Accuracy:  0.73046875
Accuracy:  0.765625
Accuracy:  0.73046875
Accuracy:  0.73046875


 95%|█████████▍| 334/352 [00:22<00:00, 24.50it/s]

Accuracy:  0.78125
Accuracy:  0.86328125
Accuracy:  0.765625
Accuracy:  0.796875
Accuracy:  0.68359375


 97%|█████████▋| 340/352 [00:22<00:00, 21.77it/s]

Accuracy:  0.3984375
Accuracy:  0.21484375
Accuracy:  0.3671875
Accuracy:  0.390625


 98%|█████████▊| 346/352 [00:22<00:00, 22.38it/s]

Accuracy:  0.6171875
Accuracy:  0.6484375
Accuracy:  0.52734375
Accuracy:  0.64453125
Accuracy:  0.71875
Accuracy:  0.81640625


 99%|█████████▉| 349/352 [00:22<00:00, 22.79it/s]

Accuracy:  0.828125
Accuracy:  0.80859375
Accuracy:  0.703125
Accuracy:  0.609375
Accuracy:  0.625


100%|██████████| 352/352 [00:23<00:00, 14.96it/s]


In [None]:
# Unweighted mean of the per-batch accuracies.
# NOTE(review): this iterates `accuracies` as a sequence of tensors, while evaluate()
# returns a (list, mean) pair -- confirm the cell that assigns `accuracies` unpacks it.
torch.mean(torch.concatenate([acc.reshape((-1)) for acc in accuracies]))

tensor(0.7342, device='cuda:0')

In [37]:
# Accuracy on the training loader, with the model in evaluation mode.
model.eval()
accuracies_train, acc_train = evaluate(model, cinic_train)
acc_train

100%|██████████| 352/352 [03:10<00:00,  1.85it/s]


tensor(0.8439, device='cuda:0')

In [33]:
# Accuracy on the validation loader.
# NOTE(review): this cell does not call model.eval() itself; it relies on an
# earlier cell having switched the model to evaluation mode.
accuracies_val, acc_val = evaluate(model, cinic_val)
acc_val

100%|██████████| 352/352 [02:56<00:00,  2.00it/s]


tensor(0.7361, device='cuda:0')