In [None]:
!pip install --upgrade wandb==0.22.3

# Training

---

In [None]:
from train import train
from model import EncoderNoPooling, LinearClassifier, CNN, CNNCrown
from losses import SupConLoss
from verifier import PGDVerifier

In [None]:
# used for logging
import os

import wandb

# Take the API key from the environment rather than writing it into the
# notebook. The previous placeholder (`...`) handed the Ellipsis object to
# wandb.login. When WANDB_API_KEY is unset this passes key=None.
# NOTE(review): with key=None wandb is expected to fall back to its own
# credential lookup / prompt — confirm for the pinned wandb version.
wandb_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=wandb_key)

In [None]:
import io
import umap
import torch
import torch.nn as nn
from PIL import Image
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

In [None]:
def get_device():
    """Select the compute device and print which one was chosen.

    Preference order: CUDA, then the MPS backend (only when it reports being
    both available and built), then CPU.

    Returns:
        torch.device: the selected device.
    """
    # NVIDIA GPU
    if torch.cuda.is_available():
        print(f"Using CUDA: {torch.cuda.get_device_name(0)}")
        return torch.device("cuda")

    # Apple Silicon GPU (MPS)
    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        print("Using MPS (Apple Silicon GPU)")
        return torch.device("mps")

    # Nothing else usable: stay on the CPU
    print("Using CPU")
    return torch.device("cpu")

In [None]:
# Notebook-wide settings: batch size handed to the DataLoaders, the projection
# size passed to the encoders/classifiers, and the device everything runs on.
BATCH_SIZE, PROJ_DIM = 64, 128
DEVICE = get_device()

In [None]:
class RandomGaussianNoise(nn.Module):
    """Transform that adds Gaussian noise to a tensor with probability ``p``.

    Args:
        mean: constant added to the sampled noise.
        std: scale applied to the standard-normal sample.
        p: threshold compared against one uniform draw per call; the input is
            perturbed unless the draw is strictly greater than ``p``.
    """

    def __init__(self, mean=0.0, std=0.05, p=0.5):
        super().__init__()
        self.mean, self.std, self.p = mean, std, p

    def forward(self, x):
        """Return ``x`` itself, or ``x`` plus noise of the same shape."""
        # A single uniform sample decides whether this call perturbs the input.
        skip_noise = bool(torch.rand(1) > self.p)
        if not skip_noise:
            return x + (torch.randn_like(x) * self.std + self.mean)
        return x

# Per-channel normalisation shared by every split: (x - 0.5) / 0.5.
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training-only augmentation. For a run without augmentation, point
# `transform` at `eval_transform` instead.
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # Critical for translation
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),  # Reduced from 180° (too aggressive)
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),  # Stronger
    transforms.ToTensor(),
    normalize,
    RandomGaussianNoise(p=0.5),
])

# Deterministic preprocessing for validation and test data, so the reported
# metrics are measured on clean images rather than on random augmentations.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

dataset = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)
# Second view of the same training files without augmentation; it backs the
# validation subset. The files were already fetched by the line above.
eval_view_dataset = datasets.CIFAR10(root="data", train=True, download=False, transform=eval_transform)

train_ratio, validation_ratio = 0.8, 0.2
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
validation_size = dataset_size - train_size

# A dedicated seeded generator keeps the train/validation partition identical
# across kernel restarts without touching the global RNG state.
split_generator = torch.Generator().manual_seed(42)
train_dataset, validation_split = random_split(
    dataset, [train_size, validation_size], generator=split_generator
)
# Same held-out indices, but read through the non-augmenting view.
validation_dataset = torch.utils.data.Subset(eval_view_dataset, validation_split.indices)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

test_dataset = datasets.CIFAR10(root="data", train=False, download=True, transform=eval_transform)
# Evaluation does not need shuffling.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# --- Encoder Training
# Hyperparameters for this stage.
# NOTE(review): the DataLoaders were created before this cell, so rebinding
# BATCH_SIZE here does not change the batch size they produce.
EPOCHS, BATCH_SIZE = 20, 128
learning_rate = 1e-3

# Objective, model and optimizer.
sup_con_loss = SupConLoss()
encoder = EncoderNoPooling(in_channels=3, proj_dim=PROJ_DIM).to(DEVICE)
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)

In [None]:
# Open the W&B run for the encoder stage.
# NOTE(review): the run id is pinned and resume="allow" is set, so re-running
# this cell presumably continues that same run — confirm this is intended.
wandb.init(
    project="Cnn-Verification",
    name="No Pooling Encoder - Augmentation",
    id="jmvbvb4r",
    resume="allow",
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        # Record the batch size the loader really uses: BATCH_SIZE is rebound
        # after the loaders are built and can disagree with them.
        "batch_size": train_loader.batch_size,
        "projection_dimension": PROJ_DIM,
        "loss": "Supervised Contrastive Loss"
    }
)

In [None]:
# Run the project's `train` routine on the encoder with the supervised
# contrastive loss and rebind `encoder` to whatever it returns.
# NOTE(review): compute_accuracy=False presumably skips accuracy because the
# encoder emits embeddings rather than class logits — confirm in train.py.
encoder = train(
    encoder,
    train_loader,
    validation_loader,
    encoder_optimizer,
    sup_con_loss,
    EPOCHS,
    DEVICE,
    compute_accuracy=False,
    wandb_logging=True
)

In [None]:
encoder.eval()

# Number of samples embedded per split. The cap is counted in samples, so it
# no longer depends on a BATCH_SIZE value that may differ from the loaders'.
N_SAMPLES = 500

plt.figure(figsize=(10, 4))

# One panel per split: identical collection, UMAP projection and scatter plot.
for panel_index, (split_name, loader) in enumerate(
    [("Train", train_loader), ("Validation", validation_loader)], start=1
):
    all_embeddings = []
    all_labels = []
    n_collected = 0

    with torch.no_grad():
        for images, labels in loader:
            embeddings = encoder(images.to(DEVICE))

            all_embeddings.append(embeddings.cpu())
            all_labels.append(labels.cpu())

            n_collected += labels.shape[0]
            if n_collected >= N_SAMPLES:
                break

    # Trim the overshoot of the last batch so both panels show N_SAMPLES points.
    embeddings = torch.cat(all_embeddings, dim=0)[:N_SAMPLES].numpy()
    labels = torch.cat(all_labels, dim=0)[:N_SAMPLES].numpy()

    # umap computation
    umap_reducer = umap.UMAP(
        n_components=2,
        n_neighbors=15,
        min_dist=0.1,
        metric="euclidean",
        random_state=42
    )
    embeddings_2d = umap_reducer.fit_transform(embeddings)

    # plotting
    plt.subplot(1, 2, panel_index)
    scatter = plt.scatter(
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        c=labels,
        cmap="tab10",
        s=5
    )
    plt.title(f"UMAP of Embeddings ({split_name})")
    plt.colorbar(scatter, ticks=range(10))

# Render the figure into an in-memory PNG that the logging cell reads back.
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)

In [None]:
# logging embeddings
wandb.log({"embeddings_space": wandb.Image(Image.open(buf))})

# logging weights
model_filename = "encoder_weights.pt"
torch.save(encoder.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Attach the file from where torch.save just wrote it (relative to the working
# directory) instead of assuming the notebook runs in /kaggle/working.
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
# Close the W&B run opened for the encoder stage.
wandb.finish()

In [None]:
# --- Classifier Training
# Hyperparameters for this stage.
# NOTE(review): the DataLoaders were created before this cell, so rebinding
# BATCH_SIZE here does not change the batch size they produce.
EPOCHS, BATCH_SIZE = 10, 128
learning_rate = 0.001

# Objective, linear head on top of PROJ_DIM-sized inputs, and its optimizer.
cross_entropy_loss = nn.CrossEntropyLoss()
classifier = LinearClassifier(in_dim=PROJ_DIM, num_classes=10).to(DEVICE)
classifier_optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)

In [None]:
# Open the W&B run for the classifier stage.
wandb.init(
    project="Cnn-Verification",
    name="No Pooling Classifier - Augmentation",
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        # Record the batch size the loader really uses: BATCH_SIZE is rebound
        # after the loaders are built and can disagree with them.
        "batch_size": train_loader.batch_size,
        "projection_dimension": PROJ_DIM,
        "loss": "CrossEntropyLoss"
    }
)

In [None]:
def execute_classifier(images:torch.Tensor, labels:torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn a batch of images into encoder embeddings, leaving labels untouched.

    The notebook-level ``encoder`` is switched to eval mode and run without
    gradient tracking.

    Args:
        images: batch handed to ``encoder``.
        labels: labels for that batch; returned unchanged.

    Returns:
        The pair ``(encoder(images), labels)``.
    """
    encoder.eval()
    with torch.no_grad():
        return encoder(images), labels

# Train the linear classifier with cross-entropy and rebind `classifier` to
# the returned model. `execute_classifier` is passed as the middleware.
# NOTE(review): presumably `train` applies the middleware to every batch before
# the forward pass, so the classifier sees embeddings — confirm in train.py.
classifier = train(
    classifier,
    train_loader,
    validation_loader,
    classifier_optimizer,
    cross_entropy_loss,
    EPOCHS,
    DEVICE,
    middleware=execute_classifier,
    wandb_logging=True
)

In [None]:
# logging weights
model_filename = "classifier_weights.pt"
torch.save(classifier.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Attach the file from where torch.save just wrote it (relative to the working
# directory) instead of assuming the notebook runs in /kaggle/working.
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
# Close the W&B run opened for the classifier stage.
wandb.finish()

In [None]:
# --- Full Model Training
# Hyperparameters for this stage.
# NOTE(review): the DataLoaders were created before this cell, so rebinding
# BATCH_SIZE here does not change the batch size they produce.
EPOCHS, BATCH_SIZE = 30, 128
learning_rate = 0.001
pooling = False

# Objective, end-to-end model and optimizer.
# NOTE(review): no proj_dim is passed here, so CNNCrown uses its own default —
# confirm it matches the PROJ_DIM logged for this run.
cross_entropy_loss = nn.CrossEntropyLoss()
full_model = CNNCrown(in_channels=3, num_classes=10, pooling=pooling).to(DEVICE)
full_model_optimizer = optim.Adam(full_model.parameters(), lr=learning_rate)

In [None]:
# Open the W&B run for the full-model stage.
wandb.init(
    project="Cnn-Verification",
    name="No Pooling - Full Model - Augmentation",
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        # Record the batch size the loader really uses: BATCH_SIZE is rebound
        # after the loaders are built and can disagree with them.
        "batch_size": train_loader.batch_size,
        "projection_dimension": PROJ_DIM,
        "loss": "CrossEntropy"
    }
)

In [None]:
# Train the end-to-end CNNCrown model with cross-entropy and rebind
# `full_model` to the model returned by `train`.
full_model = train(
    full_model,
    train_loader,
    validation_loader,
    full_model_optimizer,
    cross_entropy_loss,
    EPOCHS,
    DEVICE,
    wandb_logging=True
)

In [None]:
full_model.eval()

# Number of samples embedded per split. The cap is counted in samples, so it
# no longer depends on a BATCH_SIZE value that may differ from the loaders'.
N_SAMPLES = 500

plt.figure(figsize=(10, 4))

# One panel per split: identical collection, UMAP projection and scatter plot.
for panel_index, (split_name, loader) in enumerate(
    [("Train", train_loader), ("Validation", validation_loader)], start=1
):
    all_embeddings = []
    all_labels = []
    n_collected = 0

    with torch.no_grad():
        for images, labels in loader:
            # Only the encoder part of the full model is used for the plot.
            embeddings = full_model.encoder(images.to(DEVICE))

            all_embeddings.append(embeddings.cpu())
            all_labels.append(labels.cpu())

            n_collected += labels.shape[0]
            if n_collected >= N_SAMPLES:
                break

    # Trim the overshoot of the last batch so both panels show N_SAMPLES points.
    embeddings = torch.cat(all_embeddings, dim=0)[:N_SAMPLES].numpy()
    labels = torch.cat(all_labels, dim=0)[:N_SAMPLES].numpy()

    # umap computation
    umap_reducer = umap.UMAP(
        n_components=2,
        n_neighbors=15,
        min_dist=0.1,
        metric="euclidean",
        random_state=42
    )
    embeddings_2d = umap_reducer.fit_transform(embeddings)

    # plotting
    plt.subplot(1, 2, panel_index)
    scatter = plt.scatter(
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        c=labels,
        cmap="tab10",
        s=5
    )
    plt.title(f"UMAP of Embeddings ({split_name})")
    plt.colorbar(scatter, ticks=range(10))

# Render the figure into an in-memory PNG that the logging cell reads back.
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)

In [None]:
# logging embeddings
wandb.log({"embeddings_space": wandb.Image(Image.open(buf))})

# logging weights
model_filename = "full_model_weights.pt"
torch.save(full_model.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Attach the file from where torch.save just wrote it (relative to the working
# directory) instead of assuming the notebook runs in /kaggle/working.
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
# Close the W&B run opened for the full-model stage.
wandb.finish()

In [None]:
# --- Adversarial Training
# !!! IMPORTANT !!! Remember to halve BATCH_SIZE because we add adversarial examples

# Hyperparameters for this stage (BATCH_SIZE is deliberately not set here).
EPOCHS, learning_rate = 10, 0.001
pooling = False

# Objective, model, optimizer, and the PGD attack used to craft adversarial inputs.
cross_entropy_loss = nn.CrossEntropyLoss()
adversarial_model = CNNCrown(in_channels=3, proj_dim=PROJ_DIM, num_classes=10, pooling=pooling).to(DEVICE)
adversarial_model_optimizer = optim.Adam(adversarial_model.parameters(), lr=learning_rate)
pgd = PGDVerifier(device=DEVICE)

In [None]:
# Open the W&B run for the adversarial full-model stage.
wandb.init(
    project="Cnn-Verification",
    name="No Pooling - Adversarial Model - No Augmentation",
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        # Effective batch size: every loader batch is doubled by the
        # adversarial middleware. Read it from the loader, because BATCH_SIZE
        # may have been rebound after the loaders were built.
        "batch_size": train_loader.batch_size * 2,
        "projection_dimension": PROJ_DIM,
        "loss": "CrossEntropy"
    }
)

In [None]:
def compute_adversarial_examples(images:torch.Tensor, labels:torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Extend a batch with PGD adversarial examples crafted against ``adversarial_model``.

    Args:
        images: batch of inputs handed to the attack.
        labels: labels aligned with ``images``.

    Returns:
        ``(images, labels)`` with the adversarial examples appended after the
        originals and the labels repeated to match, i.e. a batch twice as large.
    """
    adversarial_examples, _, _ = pgd.verify(adversarial_model, images, labels, clamp_min=-1, clamp_max=1)
    # detach() rather than `requires_grad = False`: the flag can only be
    # changed on leaf tensors, while detach() works for any tensor and also
    # cuts the examples off from the attack's autograd graph.
    adversarial_examples = adversarial_examples.detach()
    images = torch.cat([images, adversarial_examples])
    labels = torch.cat([labels, labels])
    return images, labels

# Adversarially train the CNNCrown model with cross-entropy and rebind
# `adversarial_model` to the result. `compute_adversarial_examples` is the
# middleware, so batches are extended with PGD examples.
# NOTE(review): whether validation batches also go through the middleware
# depends on `train` — confirm in train.py.
adversarial_model = train(
    adversarial_model,
    train_loader,
    validation_loader,
    adversarial_model_optimizer,
    cross_entropy_loss,
    EPOCHS,
    DEVICE,
    middleware=compute_adversarial_examples,
    wandb_logging=True
)

In [None]:
adversarial_model.eval()

# Number of samples embedded per split. The cap is counted in samples, so it
# no longer depends on a BATCH_SIZE value that may differ from the loaders'.
N_SAMPLES = 500

plt.figure(figsize=(10, 4))

# One panel per split: identical collection, UMAP projection and scatter plot.
for panel_index, (split_name, loader) in enumerate(
    [("Train", train_loader), ("Validation", validation_loader)], start=1
):
    all_embeddings = []
    all_labels = []
    n_collected = 0

    with torch.no_grad():
        for images, labels in loader:
            # Only the encoder part of the adversarial model is used for the plot.
            embeddings = adversarial_model.encoder(images.to(DEVICE))

            all_embeddings.append(embeddings.cpu())
            all_labels.append(labels.cpu())

            n_collected += labels.shape[0]
            if n_collected >= N_SAMPLES:
                break

    # Trim the overshoot of the last batch so both panels show N_SAMPLES points.
    embeddings = torch.cat(all_embeddings, dim=0)[:N_SAMPLES].numpy()
    labels = torch.cat(all_labels, dim=0)[:N_SAMPLES].numpy()

    # umap computation
    umap_reducer = umap.UMAP(
        n_components=2,
        n_neighbors=15,
        min_dist=0.1,
        metric="euclidean",
        random_state=42
    )
    embeddings_2d = umap_reducer.fit_transform(embeddings)

    # plotting
    plt.subplot(1, 2, panel_index)
    scatter = plt.scatter(
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        c=labels,
        cmap="tab10",
        s=5
    )
    plt.title(f"UMAP of Embeddings ({split_name})")
    plt.colorbar(scatter, ticks=range(10))

# Render the figure into an in-memory PNG that the logging cell reads back.
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)

In [None]:
# logging embeddings
wandb.log({"embeddings_space": wandb.Image(Image.open(buf))})

# logging weights
model_filename = "adversarial_model_weights.pt"
torch.save(adversarial_model.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Attach the file from where torch.save just wrote it (relative to the working
# directory) instead of assuming the notebook runs in /kaggle/working.
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
# Close the W&B run opened for the adversarial full-model stage.
wandb.finish()

In [None]:
# --- Adversarial Training With Supervised Contrastive Loss
# --- This cell sets up the encoder stage
# !!! IMPORTANT !!! Remember to halve BATCH_SIZE because we add adversarial examples

# --- Encoder Training
# Hyperparameters for this stage (BATCH_SIZE is deliberately not set here).
EPOCHS, learning_rate = 50, 1e-4

# Objective, encoder, optimizer, and the PGD attack used to craft adversarial inputs.
sup_con_loss = SupConLoss(temperature=0.1)
adversarial_encoder = EncoderNoPooling(in_channels=3, proj_dim=PROJ_DIM).to(DEVICE)
adversarial_encoder_optimizer = optim.Adam(adversarial_encoder.parameters(), lr=learning_rate)
pgd = PGDVerifier(device=DEVICE)

In [None]:
# Open the W&B run for the adversarial encoder stage.
wandb.init(
    project="Cnn-Verification",
    name="No Pooling - Adversarial Encoder - Augmentation",
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        # Effective batch size: every loader batch is doubled by the
        # adversarial middleware. Read it from the loader, because BATCH_SIZE
        # may have been rebound after the loaders were built.
        "batch_size": train_loader.batch_size * 2,
        "projection_dimension": PROJ_DIM,
        "loss": "Supervised Contrastive Loss"
    }
)

In [None]:
def compute_adversarial_examples(images:torch.Tensor, labels:torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Extend a batch with PGD adversarial examples crafted against ``adversarial_encoder``.

    The attack is driven by the supervised contrastive loss (``sup_con_loss``)
    instead of the verifier's default criterion.

    Args:
        images: batch of inputs handed to the attack.
        labels: labels aligned with ``images``.

    Returns:
        ``(images, labels)`` with the adversarial examples appended after the
        originals and the labels repeated to match, i.e. a batch twice as large.
    """
    adversarial_examples, _, _ = pgd.verify(adversarial_encoder, images, labels, clamp_min=-1, clamp_max=1, criterion=sup_con_loss)
    # detach() rather than `requires_grad = False`: the flag can only be
    # changed on leaf tensors, while detach() works for any tensor and also
    # cuts the examples off from the attack's autograd graph.
    adversarial_examples = adversarial_examples.detach()
    images = torch.cat([images, adversarial_examples])
    labels = torch.cat([labels, labels])
    return images, labels

# Adversarially train the encoder with the supervised contrastive loss and
# rebind `adversarial_encoder` to the result. `compute_adversarial_examples`
# is the middleware, so batches are extended with PGD examples.
# NOTE(review): compute_accuracy=False presumably skips accuracy because the
# encoder emits embeddings rather than class logits — confirm in train.py.
adversarial_encoder = train(
    adversarial_encoder,
    train_loader,
    validation_loader,
    adversarial_encoder_optimizer,
    sup_con_loss,
    EPOCHS,
    DEVICE,
    middleware=compute_adversarial_examples,
    compute_accuracy=False,
    wandb_logging=True,
)

In [None]:
adversarial_encoder.eval()

# Number of samples embedded per split. The cap is counted in samples, so it
# no longer depends on a BATCH_SIZE value that may differ from the loaders'.
N_SAMPLES = 500

plt.figure(figsize=(10, 4))

# One panel per split: identical collection, UMAP projection and scatter plot.
for panel_index, (split_name, loader) in enumerate(
    [("Train", train_loader), ("Validation", validation_loader)], start=1
):
    all_embeddings = []
    all_labels = []
    n_collected = 0

    with torch.no_grad():
        for images, labels in loader:
            embeddings = adversarial_encoder(images.to(DEVICE))

            all_embeddings.append(embeddings.cpu())
            all_labels.append(labels.cpu())

            n_collected += labels.shape[0]
            if n_collected >= N_SAMPLES:
                break

    # Trim the overshoot of the last batch so both panels show N_SAMPLES points.
    embeddings = torch.cat(all_embeddings, dim=0)[:N_SAMPLES].numpy()
    labels = torch.cat(all_labels, dim=0)[:N_SAMPLES].numpy()

    # umap computation
    umap_reducer = umap.UMAP(
        n_components=2,
        n_neighbors=15,
        min_dist=0.1,
        metric="euclidean",
        random_state=42
    )
    embeddings_2d = umap_reducer.fit_transform(embeddings)

    # plotting
    plt.subplot(1, 2, panel_index)
    scatter = plt.scatter(
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        c=labels,
        cmap="tab10",
        s=5
    )
    plt.title(f"UMAP of Embeddings ({split_name})")
    plt.colorbar(scatter, ticks=range(10))

# Render the figure into an in-memory PNG that the logging cell reads back.
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)

In [None]:
# logging embeddings
wandb.log({"embeddings_space": wandb.Image(Image.open(buf))})

# logging weights
model_filename = "adversarial_encoder_weights.pt"
torch.save(adversarial_encoder.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Attach the file from where torch.save just wrote it (relative to the working
# directory) instead of assuming the notebook runs in /kaggle/working.
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
# Close the W&B run opened for the adversarial encoder stage.
wandb.finish()

In [None]:
# --- Classifier stage on top of the adversarially trained encoder
EPOCHS = 10
learning_rate = 0.001
cross_entropy_loss = nn.CrossEntropyLoss()
adversarial_classifier = LinearClassifier(in_dim=PROJ_DIM, num_classes=10).to(DEVICE)
adversarial_classifier_optimizer = optim.Adam(adversarial_classifier.parameters(), lr=learning_rate)
pgd = PGDVerifier(device=DEVICE)

# The encoder is rebuilt from saved weights, so this cell no longer touches the
# `adversarial_encoder` left over from the training cell. The earlier leading
# `.eval()` call acted on an object that is replaced right here and raised a
# NameError when that training cell had not been run.
adversarial_encoder = EncoderNoPooling(in_channels=3, proj_dim=PROJ_DIM)
adversarial_encoder.load_state_dict(torch.load("/kaggle/input/cnnrobust/pytorch/nopooling_models/3/adversarial_encoder_weights_augmented.pt", map_location=DEVICE))
adversarial_encoder.eval()

# Full model assembled from the loaded encoder and the classifier being
# trained; it is the model the PGD attack is run against in this stage.
cnn = CNNCrown(pooling=False)
cnn.encoder = adversarial_encoder
cnn.classifier = adversarial_classifier
cnn.to(DEVICE)
cnn.eval()

In [None]:
# Open the W&B run for the adversarial contrastive classifier stage.
wandb.init(
    project="Cnn-Verification",
    name="No Pooling - Adversarial Contrastive Classifier - Augmentation",
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        # Effective batch size: every loader batch is doubled by the
        # adversarial middleware. Read it from the loader, because BATCH_SIZE
        # may have been rebound after the loaders were built.
        "batch_size": train_loader.batch_size * 2,
        "projection_dimension": PROJ_DIM,
        "loss": "CrossEntropy Loss"
    }
)

In [None]:
def compute_adversarial_examples(images:torch.Tensor, labels:torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Build clean + adversarial embeddings for training the classifier head.

    PGD examples are crafted against the assembled ``cnn``, appended to the
    clean batch, and the combined batch is encoded by ``adversarial_encoder``
    in eval mode without gradient tracking.

    Args:
        images: batch of inputs handed to the attack.
        labels: labels aligned with ``images``.

    Returns:
        ``(embeddings, labels)`` where ``embeddings`` covers the clean images
        followed by their adversarial counterparts and ``labels`` is repeated
        to match.
    """
    adversarial_examples, _, _ = pgd.verify(cnn, images, labels, clamp_min=-1, clamp_max=1)
    # detach() rather than `requires_grad = False`: the flag can only be
    # changed on leaf tensors, while detach() works for any tensor and also
    # cuts the examples off from the attack's autograd graph.
    adversarial_examples = adversarial_examples.detach()
    images = torch.cat([images, adversarial_examples])
    labels = torch.cat([labels, labels])
    with torch.no_grad():
        adversarial_encoder.eval()
        embeddings = adversarial_encoder(images)
    return embeddings, labels

# Train the linear classifier with cross-entropy and rebind
# `adversarial_classifier` to the result. `compute_adversarial_examples` is
# the middleware, so the classifier is fed encoder embeddings of clean and
# PGD-perturbed images.
adversarial_classifier = train(
    adversarial_classifier,
    train_loader,
    validation_loader,
    adversarial_classifier_optimizer,
    cross_entropy_loss,
    EPOCHS,
    DEVICE,
    middleware=compute_adversarial_examples,
    wandb_logging=True,
)

In [None]:
# logging weights
model_filename = "adversarial_classifier_weights.pt"
torch.save(adversarial_classifier.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Attach the file from where torch.save just wrote it (relative to the working
# directory) instead of assuming the notebook runs in /kaggle/working.
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
# Close the W&B run opened for the adversarial classifier stage.
wandb.finish()

---

# Testing

In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from model import CNNCrown

In [None]:
# loading all the models
# Use the GPU when one is present and fall back to the CPU otherwise, instead
# of assuming CUDA is always available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

base_path = "/kaggle/input/cnnrobust/pytorch/nopooling_models/4"
augmentation_path = f"{base_path}/augmentation"
no_augmentation_path = f"{base_path}/no_augmentation"

# The same four checkpoints exist in both directories; this order must stay
# aligned with `models_name` below.
checkpoint_files = [
    "normal_model.pt",
    "contrastive_model.pt",
    "adversarial_model.pt",
    "adversarial_contrastive_model.pt",
]

# Load onto the CPU so checkpoints saved from a GPU also open on CPU-only
# machines; the evaluation cell moves each model to DEVICE.
models_weights = [
    torch.load(f"{directory}/{filename}", map_location="cpu")
    for directory in (augmentation_path, no_augmentation_path)
    for filename in checkpoint_files
]

models = []

for weights in models_weights:
    model = CNNCrown(pooling=False)
    model.load_state_dict(weights)
    models.append(model)

# First four entries: augmentation runs; last four: no-augmentation runs.
models_name = ["Normal Model", "Contrastive Model", "Adversarial Model", "Adversarial Contrastive"] * 2

In [None]:
BATCH_SIZE = 64

# Seed the global RNG before random_split so the partition below is repeatable.
torch.manual_seed(42)

# Evaluation preprocessing only: tensor conversion and per-channel normalisation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Official CIFAR-10 train and test sets, both read through the same transform.
dataset = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)

# 80/20 partition of the training set into train and validation subsets.
train_ratio, validation_ratio = 0.8, 0.2
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
validation_size = dataset_size - train_size
train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size])

# One loader per split.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=8)
validation_loader = DataLoader(validation_dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=8)
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=8)

In [None]:
def _accuracy_percent(model, loader, device):
    """Return the percentage of samples in ``loader`` that ``model`` labels correctly.

    Correct predictions and samples are counted over the whole loader, so a
    shorter final batch is weighted by its real size instead of counting as
    much as a full batch.

    Args:
        model: callable mapping a batch of images to per-class scores.
        loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        Accuracy in percent, or NaN when the loader yields no samples.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            predictions = torch.argmax(model(images), dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total if total else float("nan")


print("> Accuracy")
for model_id, (model, model_name) in enumerate(zip(models, models_name)):
    model.eval()
    model.to(DEVICE)

    train_accuracy = _accuracy_percent(model, train_loader, DEVICE)
    test_accuracy = _accuracy_percent(model, test_loader, DEVICE)

    # The list `models` keeps a reference to every model, so `del model` could
    # not free anything; move the model back to the CPU so the evaluated
    # models do not pile up in accelerator memory.
    model.to("cpu")

    if model_id == 0:
        print("\t- Augmentation")
    if model_id == 4:
        print("\t- No Augmentation")

    print(f"\t\t- {model_name}: {train_accuracy:.2f}% -> {test_accuracy:.2f}%")