In [None]:
!pip install -q torch torchvision tqdm

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import math
from pathlib import Path

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

Device: cuda


In [None]:
# CELL 2 — configuration / hyperparameters
IMG_SIZE = 32            # CIFAR-10 native size
PATCH_SIZE = 4           # 32 / 4 = 8 -> 8x8 patches = 64 patches
EMBED_DIM = 384          # transformer width (token size); lower it if the GPU runs out of memory
NUM_HEADS = 6            # EMBED_DIM must be divisible by NUM_HEADS
NUM_LAYERS = 6           # number of transformer blocks
MLP_RATIO = 4            # MLP hidden width = MLP_RATIO * EMBED_DIM
DROPOUT = 0.1

NUM_CLASSES = 10
BATCH_SIZE = 256         # reduce if OOM
EPOCHS = 60
LR = 3e-4
WEIGHT_DECAY = 0.05
SAVE_PATH = "dalle_cifar10_best.pth"

# Mixed precision toggle (only takes effect when running on CUDA)
USE_AMP = True

# Seed torch's global RNG (CPU and CUDA) so weight init, the train/val split,
# DataLoader shuffling and augmentation repeat across Restart & Run All.
# Some GPU kernels may still introduce small run-to-run differences.
SEED = 42
torch.manual_seed(SEED)

# Fail fast on incompatible settings instead of deep inside the model.
assert IMG_SIZE % PATCH_SIZE == 0, "IMG_SIZE must be divisible by PATCH_SIZE"
assert EMBED_DIM % NUM_HEADS == 0, "EMBED_DIM must be divisible by NUM_HEADS"

print(f"Config: IMG={IMG_SIZE}, PATCH={PATCH_SIZE}, EMBED={EMBED_DIM}, HEADS={NUM_HEADS}, LAYERS={NUM_LAYERS}")


Config: IMG=32, PATCH=4, EMBED=384, HEADS=6, LAYERS=6


In [None]:
# CELL 3 — model (DALLE-style Vision Transformer adapted for CIFAR-10)
# This follows the same architecture pattern used in the project's train_new.py; see that file for reference.

class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one linearly.

    A convolution whose kernel size equals its stride is the same as flattening
    every ``patch_size x patch_size`` patch and applying one shared linear layer.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of each patch; must divide ``img_size``.
        in_channels: number of input image channels.
        embed_dim: size of each output patch embedding.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=384):
        super().__init__()
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Turn a (B, C, H, W) batch into a (B, num_patches, embed_dim) token sequence."""
        feature_map = self.proj(x)                  # (B, embed_dim, H/ps, W/ps)
        tokens = feature_map.flatten(start_dim=2)   # (B, embed_dim, num_patches)
        return tokens.transpose(1, 2)               # (B, num_patches, embed_dim)


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        embed_dim: token width C; must be divisible by ``num_heads``.
        num_heads: number of heads; each head works on C // num_heads channels.
        dropout: dropout probability applied to the attention weights and to
            the output projection.
    """

    def __init__(self, embed_dim=384, num_heads=6, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        # 1/sqrt(head_dim) keeps the dot-product logits well scaled before softmax
        self.scale = self.head_dim ** -0.5
        # one fused projection produces queries, keys and values together
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        """Self-attend over x of shape (B, N, C) and return a tensor of the same shape."""
        batch, seq_len, channels = x.shape
        fused = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        # (3, B, heads, N, head_dim) -> three (B, heads, N, head_dim) tensors
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale   # (B, heads, N, N)
        weights = self.attn_drop(scores.softmax(dim=-1))

        mixed = torch.matmul(weights, values)                                 # (B, heads, N, head_dim)
        mixed = mixed.transpose(1, 2).reshape(batch, seq_len, channels)       # (B, N, C)
        return self.proj_drop(self.proj(mixed))


class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The hidden width is ``int(embed_dim * mlp_ratio)``. The output width equals ``embed_dim``.
    """

    def __init__(self, embed_dim=384, mlp_ratio=4, dropout=0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        # one Dropout module is reused after the activation and after fc2
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Self-attention and then the MLP each receive a LayerNorm-ed copy of their
    input. Each sub-layer's output is added back through a residual connection.
    """

    def __init__(self, embed_dim=384, num_heads=6, mlp_ratio=4, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, dropout)

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))


class DALLEInspiredViT(nn.Module):
    """Vision Transformer classifier: patch embedding + [CLS] token + encoder + linear head.

    Args:
        img_size: side length of the square RGB input images.
        patch_size: side length of each patch; must divide ``img_size``.
        embed_dim: token width used throughout the encoder.
        num_heads: attention heads per block.
        num_layers: number of transformer blocks.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: dropout used after adding positional embeddings and inside every block.
        num_classes: number of output logits.
    """

    def __init__(self, img_size=32, patch_size=4, embed_dim=384, num_heads=6, num_layers=6,
                 mlp_ratio=4, dropout=0.0, num_classes=10):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)
        token_count = self.patch_embed.num_patches + 1   # +1 for the [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, embed_dim))
        self.pos_drop = nn.Dropout(dropout)
        encoder_layers = [
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(num_layers)
        ]
        self.blocks = nn.ModuleList(encoder_layers)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Learned positional and [CLS] embeddings start as small truncated-normal noise.
        for learned_embedding in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(learned_embedding, std=0.02)

    def forward(self, x):
        """Map images of shape (B, 3, img_size, img_size) to (B, num_classes) logits."""
        tokens = self.patch_embed(x)                              # (B, num_patches, C)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)      # (B, 1, C)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed
        tokens = self.pos_drop(tokens)
        for block in self.blocks:
            tokens = block(tokens)
        # classify from the [CLS] token after the final LayerNorm
        return self.head(self.norm(tokens)[:, 0])


In [None]:
# CELL 4 — data loaders for CIFAR-10
# Per-channel CIFAR-10 mean/std used for normalization
mean = (0.4914, 0.4822, 0.4465)
std  = (0.2470, 0.2435, 0.2616)

# Random augmentation: applied ONLY to the images the model trains on
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

# Deterministic preprocessing for validation and test
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

# Download CIFAR-10 into ./data (relative to the current working directory)
DATA_ROOT = "./data"
train_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=train_transform)
# Second view of the same 50k training images with the eval transform. The
# validation subset is drawn from this view so it is NOT randomly augmented.
train_eval_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=False, transform=test_transform)
test_dataset  = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=test_transform)

# Hold out VAL_SPLIT training images for validation. A dedicated seeded
# generator keeps the split identical across runs.
VAL_SPLIT = 5000
VAL_SPLIT_SEED = 42
train_size = len(train_dataset) - VAL_SPLIT
split_generator = torch.Generator().manual_seed(VAL_SPLIT_SEED)
shuffled_indices = torch.randperm(len(train_dataset), generator=split_generator).tolist()
train_set = torch.utils.data.Subset(train_dataset, shuffled_indices[:train_size])
val_set   = torch.utils.data.Subset(train_eval_dataset, shuffled_indices[train_size:])
assert not set(train_set.indices) & set(val_set.indices), "train/val split overlaps"

# Pinned host memory only speeds up copies to a CUDA device
PIN_MEMORY = device.type == "cuda"
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_set,   batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=PIN_MEMORY)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=PIN_MEMORY)

print("Train/Val/Test sizes:", len(train_set), len(val_set), len(test_dataset))


100%|██████████| 170M/170M [00:24<00:00, 6.88MB/s]


Train/Val/Test sizes: 45000 5000 10000




In [None]:
# CELL 5 — training helpers (train/validate)
model = DALLEInspiredViT(
    img_size=IMG_SIZE, patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM, num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS, mlp_ratio=MLP_RATIO,
    dropout=DROPOUT, num_classes=NUM_CLASSES
).to(device)

print("Params:", sum(p.numel() for p in model.parameters()))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Cosine decay over the whole run; the training loop calls scheduler.step() once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# A GradScaler is only used for fp16 autocast on CUDA; None disables the AMP code paths.
# torch.amp.GradScaler("cuda") replaces the deprecated torch.cuda.amp.GradScaler().
scaler = torch.amp.GradScaler("cuda") if (USE_AMP and device.type == "cuda") else None

def train_one_epoch(epoch):
    """Run one optimization pass of the global ``model`` over ``train_loader``.

    Uses fp16 autocast plus the global ``scaler`` when it is set. Otherwise it
    runs in full precision. Both paths clip the total gradient norm to 1.0.

    Args:
        epoch: zero-based epoch index (only used for the progress-bar label).

    Returns:
        (mean training loss per sample, training accuracy in percent)

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [train]")
    for imgs, labels in pbar:
        # non_blocking lets the host->device copy overlap with compute when memory is pinned
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        if scaler:
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast()
            with torch.autocast(device_type="cuda"):
                outputs = model(imgs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            # unscale first so clipping operates on the true gradient magnitudes
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        # criterion returns the batch mean, so re-weight by batch size for a per-sample average
        running_loss += loss.item() * imgs.size(0)
        _, preds = outputs.max(1)
        total += labels.size(0)
        correct += preds.eq(labels).sum().item()
        pbar.set_postfix(loss=running_loss/total, acc=100*correct/total)
    if total == 0:
        raise ValueError("train_one_epoch(): train_loader yielded no samples")
    return running_loss/total, 100*correct/total

@torch.no_grad()
def evaluate(loader, stage="val"):
    """Compute the mean loss and accuracy of the global ``model`` over ``loader``.

    Switches the model to eval mode (disables dropout) and runs without gradients.
    When the global ``scaler`` is set, uses fp16 autocast on CUDA, matching training.

    Args:
        loader: iterable yielding (images, labels) batches.
        stage: label shown in the progress bar (e.g. "val", "test").

    Returns:
        (mean loss per sample, accuracy in percent)

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    for imgs, labels in tqdm(loader, desc=f"[{stage}]"):
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        if scaler:
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast()
            with torch.autocast(device_type="cuda"):
                outputs = model(imgs)
                loss = criterion(outputs, labels)
        else:
            outputs = model(imgs)
            loss = criterion(outputs, labels)
        total_loss += loss.item() * imgs.size(0)
        _, preds = outputs.max(1)
        total += labels.size(0)
        correct += preds.eq(labels).sum().item()
    if total == 0:
        raise ValueError(f"evaluate(): loader for stage '{stage}' yielded no samples")
    return total_loss / total, 100*correct/total


Params: 10695562


  scaler = torch.cuda.amp.GradScaler() if (USE_AMP and device.type == "cuda") else None


In [None]:
# CELL 6 — training loop and checkpointing
best_val_acc = 0.0
best_epoch = None
history = {"train_loss":[], "train_acc":[], "val_loss":[], "val_acc":[]}

for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(epoch)
    val_loss, val_acc = evaluate(val_loader, stage="val")
    scheduler.step()   # CosineAnnealingLR(T_max=EPOCHS) advances once per epoch

    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)

    print(f"Epoch {epoch+1}: train_loss={train_loss:.4f} train_acc={train_acc:.2f}% | val_loss={val_loss:.4f} val_acc={val_acc:.2f}%")

    # Save the best model so far. Always save after the first epoch so the file
    # loaded below comes from THIS run, never a stale checkpoint from an earlier one.
    if best_epoch is None or val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
        torch.save({
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            # scheduler/scaler state make the checkpoint resumable, not just usable for inference
            "scheduler_state_dict": scheduler.state_dict(),
            "scaler_state_dict": scaler.state_dict() if scaler is not None else None,
            "epoch": epoch,
            "val_acc": val_acc
        }, SAVE_PATH)
        print(f"[SAVED] New best val_acc: {best_val_acc:.2f}% -> {SAVE_PATH}")

# Final test using the best checkpoint.
# weights_only=True: the checkpoint holds only tensors and plain Python values,
# so arbitrary unpickling is never needed.
ckpt = torch.load(SAVE_PATH, map_location=device, weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
test_loss, test_acc = evaluate(test_loader, stage="test")
print(f"Test Accuracy (best checkpoint, epoch {ckpt['epoch']+1}): {test_acc:.2f}%")


  with torch.cuda.amp.autocast():
Epoch 1/60 [train]: 100%|██████████| 176/176 [00:28<00:00,  6.20it/s, acc=30.9, loss=1.84]
  with torch.cuda.amp.autocast():
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.21it/s]


Epoch 1: train_loss=1.8418 train_acc=30.93% | val_loss=1.8570 val_acc=32.34%
[SAVED] New best val_acc: 32.34% -> dalle_cifar10_best.pth


Epoch 2/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.45it/s, acc=43.1, loss=1.55]
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.20it/s]


Epoch 2: train_loss=1.5478 train_acc=43.06% | val_loss=1.4413 val_acc=47.10%
[SAVED] New best val_acc: 47.10% -> dalle_cifar10_best.pth


Epoch 3/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.33it/s, acc=48.9, loss=1.4]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.93it/s]


Epoch 3: train_loss=1.4025 train_acc=48.90% | val_loss=1.3164 val_acc=51.90%
[SAVED] New best val_acc: 51.90% -> dalle_cifar10_best.pth


Epoch 4/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.34it/s, acc=52.5, loss=1.31]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.80it/s]


Epoch 4: train_loss=1.3066 train_acc=52.46% | val_loss=1.2679 val_acc=54.36%
[SAVED] New best val_acc: 54.36% -> dalle_cifar10_best.pth


Epoch 5/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.42it/s, acc=55.4, loss=1.24]
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.09it/s]


Epoch 5: train_loss=1.2368 train_acc=55.36% | val_loss=1.2203 val_acc=55.02%
[SAVED] New best val_acc: 55.02% -> dalle_cifar10_best.pth


Epoch 6/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.37it/s, acc=57.6, loss=1.18]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.67it/s]


Epoch 6: train_loss=1.1780 train_acc=57.55% | val_loss=1.1433 val_acc=58.10%
[SAVED] New best val_acc: 58.10% -> dalle_cifar10_best.pth


Epoch 7/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.30it/s, acc=59.5, loss=1.12]
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.05it/s]


Epoch 7: train_loss=1.1248 train_acc=59.53% | val_loss=1.1220 val_acc=59.62%
[SAVED] New best val_acc: 59.62% -> dalle_cifar10_best.pth


Epoch 8/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.36it/s, acc=60.6, loss=1.09]
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.12it/s]


Epoch 8: train_loss=1.0895 train_acc=60.64% | val_loss=1.0500 val_acc=62.06%
[SAVED] New best val_acc: 62.06% -> dalle_cifar10_best.pth


Epoch 9/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.40it/s, acc=62.7, loss=1.04]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.41it/s]


Epoch 9: train_loss=1.0431 train_acc=62.74% | val_loss=1.0427 val_acc=63.10%
[SAVED] New best val_acc: 63.10% -> dalle_cifar10_best.pth


Epoch 10/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.33it/s, acc=64.1, loss=1.01]
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.21it/s]


Epoch 10: train_loss=1.0062 train_acc=64.13% | val_loss=1.0294 val_acc=63.50%
[SAVED] New best val_acc: 63.50% -> dalle_cifar10_best.pth


Epoch 11/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.40it/s, acc=65, loss=0.977]
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.21it/s]


Epoch 11: train_loss=0.9772 train_acc=64.96% | val_loss=0.9781 val_acc=64.70%
[SAVED] New best val_acc: 64.70% -> dalle_cifar10_best.pth


Epoch 12/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.37it/s, acc=66.4, loss=0.937]
[val]: 100%|██████████| 20/20 [00:02<00:00,  6.91it/s]


Epoch 12: train_loss=0.9373 train_acc=66.39% | val_loss=0.9591 val_acc=65.24%
[SAVED] New best val_acc: 65.24% -> dalle_cifar10_best.pth


Epoch 13/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.37it/s, acc=67.6, loss=0.911]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.80it/s]


Epoch 13: train_loss=0.9110 train_acc=67.63% | val_loss=0.9618 val_acc=66.24%
[SAVED] New best val_acc: 66.24% -> dalle_cifar10_best.pth


Epoch 14/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.36it/s, acc=68.5, loss=0.882]
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.07it/s]


Epoch 14: train_loss=0.8823 train_acc=68.47% | val_loss=0.9402 val_acc=68.04%
[SAVED] New best val_acc: 68.04% -> dalle_cifar10_best.pth


Epoch 15/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.35it/s, acc=69.5, loss=0.857]
[val]: 100%|██████████| 20/20 [00:03<00:00,  6.24it/s]


Epoch 15: train_loss=0.8565 train_acc=69.49% | val_loss=0.8908 val_acc=69.04%
[SAVED] New best val_acc: 69.04% -> dalle_cifar10_best.pth


Epoch 16/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.37it/s, acc=70.3, loss=0.831]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.99it/s]


Epoch 16: train_loss=0.8305 train_acc=70.30% | val_loss=0.8715 val_acc=69.92%
[SAVED] New best val_acc: 69.92% -> dalle_cifar10_best.pth


Epoch 17/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.38it/s, acc=71.3, loss=0.804]
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.23it/s]


Epoch 17: train_loss=0.8039 train_acc=71.34% | val_loss=0.8125 val_acc=71.80%
[SAVED] New best val_acc: 71.80% -> dalle_cifar10_best.pth


Epoch 18/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.35it/s, acc=72.1, loss=0.778]
[val]: 100%|██████████| 20/20 [00:02<00:00,  6.90it/s]


Epoch 18: train_loss=0.7777 train_acc=72.14% | val_loss=0.8176 val_acc=70.58%


Epoch 19/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.42it/s, acc=73.2, loss=0.751]
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.26it/s]


Epoch 19: train_loss=0.7513 train_acc=73.16% | val_loss=0.8002 val_acc=72.12%
[SAVED] New best val_acc: 72.12% -> dalle_cifar10_best.pth


Epoch 20/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.37it/s, acc=74.1, loss=0.732]
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.16it/s]


Epoch 20: train_loss=0.7321 train_acc=74.06% | val_loss=0.7967 val_acc=72.26%
[SAVED] New best val_acc: 72.26% -> dalle_cifar10_best.pth


Epoch 21/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.39it/s, acc=74.6, loss=0.716]
[val]: 100%|██████████| 20/20 [00:02<00:00,  6.76it/s]


Epoch 21: train_loss=0.7161 train_acc=74.57% | val_loss=0.8104 val_acc=72.28%
[SAVED] New best val_acc: 72.28% -> dalle_cifar10_best.pth


Epoch 22/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.43it/s, acc=75.5, loss=0.683]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.67it/s]


Epoch 22: train_loss=0.6829 train_acc=75.53% | val_loss=0.7926 val_acc=72.90%
[SAVED] New best val_acc: 72.90% -> dalle_cifar10_best.pth


Epoch 23/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.44it/s, acc=76.2, loss=0.669]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.61it/s]


Epoch 23: train_loss=0.6685 train_acc=76.23% | val_loss=0.7667 val_acc=74.10%
[SAVED] New best val_acc: 74.10% -> dalle_cifar10_best.pth


Epoch 24/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.37it/s, acc=77.2, loss=0.644]
[val]: 100%|██████████| 20/20 [00:02<00:00,  6.78it/s]


Epoch 24: train_loss=0.6443 train_acc=77.17% | val_loss=0.7760 val_acc=73.76%


Epoch 25/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.41it/s, acc=77.5, loss=0.629]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.71it/s]


Epoch 25: train_loss=0.6290 train_acc=77.47% | val_loss=0.7614 val_acc=74.16%
[SAVED] New best val_acc: 74.16% -> dalle_cifar10_best.pth


Epoch 26/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.43it/s, acc=78.5, loss=0.603]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.57it/s]


Epoch 26: train_loss=0.6029 train_acc=78.52% | val_loss=0.7798 val_acc=73.82%


Epoch 27/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.41it/s, acc=79.2, loss=0.585]
[val]: 100%|██████████| 20/20 [00:03<00:00,  6.01it/s]


Epoch 27: train_loss=0.5846 train_acc=79.21% | val_loss=0.7499 val_acc=74.38%
[SAVED] New best val_acc: 74.38% -> dalle_cifar10_best.pth


Epoch 28/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.44it/s, acc=80, loss=0.561]
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.16it/s]


Epoch 28: train_loss=0.5606 train_acc=79.96% | val_loss=0.7294 val_acc=75.80%
[SAVED] New best val_acc: 75.80% -> dalle_cifar10_best.pth


Epoch 29/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.40it/s, acc=80.5, loss=0.541]
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.33it/s]


Epoch 29: train_loss=0.5414 train_acc=80.47% | val_loss=0.7580 val_acc=75.08%


Epoch 30/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.37it/s, acc=81.7, loss=0.519]
[val]: 100%|██████████| 20/20 [00:02<00:00,  7.08it/s]


Epoch 30: train_loss=0.5193 train_acc=81.67% | val_loss=0.7306 val_acc=76.70%
[SAVED] New best val_acc: 76.70% -> dalle_cifar10_best.pth


Epoch 31/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.42it/s, acc=82.2, loss=0.498]
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.07it/s]


Epoch 31: train_loss=0.4982 train_acc=82.17% | val_loss=0.7055 val_acc=76.76%
[SAVED] New best val_acc: 76.76% -> dalle_cifar10_best.pth


Epoch 32/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.38it/s, acc=82.5, loss=0.487]
[val]: 100%|██████████| 20/20 [00:02<00:00,  9.13it/s]


Epoch 32: train_loss=0.4875 train_acc=82.47% | val_loss=0.7053 val_acc=77.06%
[SAVED] New best val_acc: 77.06% -> dalle_cifar10_best.pth


Epoch 33/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.37it/s, acc=83.2, loss=0.467]
[val]: 100%|██████████| 20/20 [00:03<00:00,  5.83it/s]


Epoch 33: train_loss=0.4672 train_acc=83.24% | val_loss=0.7289 val_acc=76.34%


Epoch 34/60 [train]: 100%|██████████| 176/176 [00:29<00:00,  5.95it/s, acc=84.1, loss=0.445]
[val]: 100%|██████████| 20/20 [00:02<00:00,  7.95it/s]


Epoch 34: train_loss=0.4447 train_acc=84.08% | val_loss=0.7100 val_acc=77.12%
[SAVED] New best val_acc: 77.12% -> dalle_cifar10_best.pth


Epoch 35/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.33it/s, acc=84.6, loss=0.428]
[val]: 100%|██████████| 20/20 [00:02<00:00,  7.25it/s]


Epoch 35: train_loss=0.4281 train_acc=84.63% | val_loss=0.7321 val_acc=76.60%


Epoch 36/60 [train]: 100%|██████████| 176/176 [00:28<00:00,  6.26it/s, acc=85.3, loss=0.413]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.45it/s]


Epoch 36: train_loss=0.4135 train_acc=85.34% | val_loss=0.7087 val_acc=77.74%
[SAVED] New best val_acc: 77.74% -> dalle_cifar10_best.pth


Epoch 37/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.33it/s, acc=85.9, loss=0.391]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.99it/s]


Epoch 37: train_loss=0.3906 train_acc=85.93% | val_loss=0.7063 val_acc=77.68%


Epoch 38/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.39it/s, acc=86.3, loss=0.379]
[val]: 100%|██████████| 20/20 [00:02<00:00,  7.35it/s]


Epoch 38: train_loss=0.3789 train_acc=86.33% | val_loss=0.7266 val_acc=77.54%


Epoch 39/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.32it/s, acc=87, loss=0.36]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.78it/s]


Epoch 39: train_loss=0.3596 train_acc=87.02% | val_loss=0.7302 val_acc=77.62%


Epoch 40/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.38it/s, acc=87.5, loss=0.349]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.97it/s]


Epoch 40: train_loss=0.3486 train_acc=87.48% | val_loss=0.7139 val_acc=78.56%
[SAVED] New best val_acc: 78.56% -> dalle_cifar10_best.pth


Epoch 41/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.40it/s, acc=88.1, loss=0.333]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.01it/s]


Epoch 41: train_loss=0.3330 train_acc=88.05% | val_loss=0.7197 val_acc=78.38%


Epoch 42/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.30it/s, acc=88.7, loss=0.315]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.65it/s]


Epoch 42: train_loss=0.3148 train_acc=88.72% | val_loss=0.7324 val_acc=78.44%


Epoch 43/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.35it/s, acc=89.2, loss=0.302]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.61it/s]


Epoch 43: train_loss=0.3020 train_acc=89.23% | val_loss=0.7047 val_acc=79.26%
[SAVED] New best val_acc: 79.26% -> dalle_cifar10_best.pth


Epoch 44/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.34it/s, acc=89.6, loss=0.293]
[val]: 100%|██████████| 20/20 [00:02<00:00,  7.79it/s]


Epoch 44: train_loss=0.2933 train_acc=89.58% | val_loss=0.7811 val_acc=77.70%


Epoch 45/60 [train]: 100%|██████████| 176/176 [00:28<00:00,  6.27it/s, acc=89.9, loss=0.279]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.39it/s]


Epoch 45: train_loss=0.2791 train_acc=89.91% | val_loss=0.7685 val_acc=78.56%


Epoch 46/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.37it/s, acc=90.4, loss=0.268]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.24it/s]


Epoch 46: train_loss=0.2677 train_acc=90.35% | val_loss=0.7501 val_acc=78.60%


Epoch 47/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.35it/s, acc=90.9, loss=0.255]
[val]: 100%|██████████| 20/20 [00:02<00:00,  7.45it/s]


Epoch 47: train_loss=0.2550 train_acc=90.91% | val_loss=0.7571 val_acc=78.48%


Epoch 48/60 [train]: 100%|██████████| 176/176 [00:28<00:00,  6.21it/s, acc=91.2, loss=0.247]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.54it/s]


Epoch 48: train_loss=0.2468 train_acc=91.19% | val_loss=0.7614 val_acc=78.38%


Epoch 49/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.34it/s, acc=91.7, loss=0.233]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.32it/s]


Epoch 49: train_loss=0.2326 train_acc=91.69% | val_loss=0.7703 val_acc=78.46%


Epoch 50/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.35it/s, acc=91.8, loss=0.229]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.33it/s]


Epoch 50: train_loss=0.2287 train_acc=91.83% | val_loss=0.7626 val_acc=78.80%


Epoch 51/60 [train]: 100%|██████████| 176/176 [00:28<00:00,  6.28it/s, acc=92.1, loss=0.22]
[val]: 100%|██████████| 20/20 [00:02<00:00,  7.22it/s]


Epoch 51: train_loss=0.2202 train_acc=92.09% | val_loss=0.7611 val_acc=79.40%
[SAVED] New best val_acc: 79.40% -> dalle_cifar10_best.pth


Epoch 52/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.38it/s, acc=92.2, loss=0.219]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.17it/s]


Epoch 52: train_loss=0.2186 train_acc=92.19% | val_loss=0.7872 val_acc=78.68%


Epoch 53/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.36it/s, acc=92.5, loss=0.211]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.33it/s]


Epoch 53: train_loss=0.2109 train_acc=92.47% | val_loss=0.8018 val_acc=78.10%


Epoch 54/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.31it/s, acc=92.8, loss=0.206]
[val]: 100%|██████████| 20/20 [00:02<00:00,  6.71it/s]


Epoch 54: train_loss=0.2063 train_acc=92.77% | val_loss=0.7997 val_acc=78.66%


Epoch 55/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.31it/s, acc=92.8, loss=0.202]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.45it/s]


Epoch 55: train_loss=0.2017 train_acc=92.82% | val_loss=0.8066 val_acc=78.50%


Epoch 56/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.35it/s, acc=92.7, loss=0.202]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.40it/s]


Epoch 56: train_loss=0.2024 train_acc=92.73% | val_loss=0.8265 val_acc=78.62%


Epoch 57/60 [train]: 100%|██████████| 176/176 [00:28<00:00,  6.28it/s, acc=92.9, loss=0.199]
[val]: 100%|██████████| 20/20 [00:02<00:00,  7.00it/s]


Epoch 57: train_loss=0.1988 train_acc=92.88% | val_loss=0.7749 val_acc=78.92%


Epoch 58/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.34it/s, acc=93.1, loss=0.195]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.11it/s]


Epoch 58: train_loss=0.1949 train_acc=93.10% | val_loss=0.7912 val_acc=79.28%


Epoch 59/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.34it/s, acc=92.9, loss=0.198]
[val]: 100%|██████████| 20/20 [00:02<00:00,  8.13it/s]


Epoch 59: train_loss=0.1983 train_acc=92.94% | val_loss=0.7957 val_acc=79.26%


Epoch 60/60 [train]: 100%|██████████| 176/176 [00:27<00:00,  6.30it/s, acc=93.2, loss=0.193]
[val]: 100%|██████████| 20/20 [00:02<00:00,  6.76it/s]


Epoch 60: train_loss=0.1930 train_acc=93.21% | val_loss=0.7851 val_acc=79.90%
[SAVED] New best val_acc: 79.90% -> dalle_cifar10_best.pth


[test]: 100%|██████████| 40/40 [00:02<00:00, 15.29it/s]

Test Accuracy (best checkpoint): 80.44%





In [None]:
# CELL 7 — save training history and (in Colab) download the artifacts
import json

# Persist the per-epoch curves next to the checkpoint so they outlive the kernel.
HISTORY_PATH = Path(SAVE_PATH).with_suffix(".history.json")
with open(HISTORY_PATH, "w") as fh:
    json.dump(history, fh, indent=2)
print(f"Saved training history -> {HISTORY_PATH}")

# google.colab only exists inside Colab; elsewhere the files simply stay on disk.
try:
    from google.colab import files
except ImportError:
    print(f"Not running in Colab; artifacts are at {Path(SAVE_PATH).resolve()} and {HISTORY_PATH.resolve()}")
else:
    files.download(SAVE_PATH)           # prompts a browser download (Colab)
    files.download(str(HISTORY_PATH))


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>