In [12]:
from datasets import load_dataset
import pandas as pd

from tqdm import tqdm
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.get_default_device()
img_to_tensor = transforms.ToTensor()
generator = torch.Generator(device)

batch_size = 100
n_classes = 10
embed_dim = 64

In [13]:
class ImageDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return img_to_tensor(self.data[idx]["img"]).to(device), torch.tensor(self.data[idx]["label"])
    
ds = load_dataset("uoft-cs/cifar10")

train_dl = DataLoader(ImageDataset(ds["train"]),batch_size=batch_size, shuffle=True, drop_last=True, generator=generator)
test_dl = DataLoader(ImageDataset(ds["test"]), batch_size=batch_size, shuffle=False, generator=generator)

In [14]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        self.shortcut = nn.Sequential()

        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self,x):
        return self.layers(x) + self.shortcut(x)


class ResidualNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_head = nn.Sequential(
            nn.Conv2d(3, embed_dim//4, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(embed_dim//4),
            nn.ReLU()
        )
        
        self.layer1 = ResidualBlock(embed_dim//4, embed_dim//2, stride=1)
        self.layer2 = ResidualBlock(embed_dim//2, embed_dim, stride=2)
        self.layer3 = ResidualBlock(embed_dim, embed_dim, stride=2)
        self.layer4 = ResidualBlock(embed_dim, embed_dim, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.out_head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        x = self.in_head(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.out_head(x)

In [15]:
epochs = 50 
model = ResidualNetwork()
opt = torch.optim.AdamW(model.parameters())
loss_fn = nn.CrossEntropyLoss()
history = []

total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

Total parameters: 229,786


In [16]:

for epoch in range(epochs):
    total_loss = 0
    correct = 0

    for input_batch, label_batch in tqdm(train_dl):
        opt.zero_grad()

        pred_batch = model(input_batch)
        loss = loss_fn(pred_batch, label_batch)
        loss.backward()

        opt.step()

        with torch.no_grad():
            total_loss += loss.item()
            for i,label in enumerate(label_batch):
                if pred_batch[i,label.item()] == pred_batch[i].max():
                    correct+=1

    
    train_loss = total_loss/ len(train_dl)
    train_acc = correct / (len(train_dl) * batch_size)
    
    model.eval()
    test_loss, test_correct, m = 0, 0, 0

    with torch.no_grad():
        for input_batch, label_batch in test_dl:
            logits = model(input_batch)
            loss = loss_fn(logits, label_batch)

            test_loss += loss.item() * input_batch.size(0)
            preds = logits.argmax(dim=1)
            test_correct += (preds == label_batch).sum().item()
            m += input_batch.size(0)

    test_loss /= m
    test_acc = test_correct / m

    metrics = {
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'test_loss': test_loss,
        'test_acc': test_acc
    }
    print(metrics, "\n\n")
    # ---- Log metrics ----
    history.append(metrics)

    history_df = pd.DataFrame(history)
    history_df.to_csv("./history_resnet.csv", index=False)


    torch.save(model.state_dict(), "./resnet4_model.pth")

100%|██████████| 500/500 [00:29<00:00, 16.95it/s]


{'epoch': 1, 'train_loss': 1.2978455173969268, 'train_acc': 0.52948, 'test_loss': 1.15022882938385, 'test_acc': 0.579} 




100%|██████████| 500/500 [00:23<00:00, 21.71it/s]


{'epoch': 2, 'train_loss': 1.0609982424974442, 'train_acc': 0.61752, 'test_loss': 0.9472595107555389, 'test_acc': 0.662} 




100%|██████████| 500/500 [00:22<00:00, 21.89it/s]


{'epoch': 3, 'train_loss': 0.8692649726867676, 'train_acc': 0.69018, 'test_loss': 0.8437076830863952, 'test_acc': 0.6959} 




100%|██████████| 500/500 [00:23<00:00, 21.71it/s]


{'epoch': 4, 'train_loss': 0.7578418715596199, 'train_acc': 0.73108, 'test_loss': 0.772643762230873, 'test_acc': 0.7287} 




100%|██████████| 500/500 [00:22<00:00, 21.76it/s]


{'epoch': 5, 'train_loss': 0.6703017339706421, 'train_acc': 0.76394, 'test_loss': 0.7064777535200119, 'test_acc': 0.7528} 




100%|██████████| 500/500 [00:22<00:00, 21.83it/s]


{'epoch': 6, 'train_loss': 0.6095691410303116, 'train_acc': 0.78594, 'test_loss': 0.7292971050739289, 'test_acc': 0.7509} 




100%|██████████| 500/500 [00:23<00:00, 21.65it/s]


{'epoch': 7, 'train_loss': 0.5577063520550728, 'train_acc': 0.80298, 'test_loss': 0.735875244140625, 'test_acc': 0.7509} 




100%|██████████| 500/500 [00:23<00:00, 21.68it/s]


{'epoch': 8, 'train_loss': 0.5177265429496765, 'train_acc': 0.81758, 'test_loss': 0.7301773843169213, 'test_acc': 0.7529} 




100%|██████████| 500/500 [00:23<00:00, 21.71it/s]


{'epoch': 9, 'train_loss': 0.47161101779341696, 'train_acc': 0.83352, 'test_loss': 0.7003457954525948, 'test_acc': 0.7586} 




100%|██████████| 500/500 [00:23<00:00, 21.64it/s]


{'epoch': 10, 'train_loss': 0.4363739545941353, 'train_acc': 0.84524, 'test_loss': 0.7002352878451348, 'test_acc': 0.7695} 




100%|██████████| 500/500 [00:23<00:00, 21.71it/s]


{'epoch': 11, 'train_loss': 0.39615178617835045, 'train_acc': 0.85962, 'test_loss': 0.7362655806541443, 'test_acc': 0.7726} 




100%|██████████| 500/500 [00:23<00:00, 21.63it/s]


{'epoch': 12, 'train_loss': 0.3719296853840351, 'train_acc': 0.86748, 'test_loss': 0.7292077508568764, 'test_acc': 0.7742} 




100%|██████████| 500/500 [00:22<00:00, 21.94it/s]


{'epoch': 13, 'train_loss': 0.33113251033425334, 'train_acc': 0.88066, 'test_loss': 0.8224681374430657, 'test_acc': 0.7554} 




100%|██████████| 500/500 [00:22<00:00, 22.11it/s]


{'epoch': 14, 'train_loss': 0.31220801240205764, 'train_acc': 0.8876, 'test_loss': 0.8128524458408356, 'test_acc': 0.7574} 




100%|██████████| 500/500 [00:22<00:00, 22.01it/s]


{'epoch': 15, 'train_loss': 0.29065942661464217, 'train_acc': 0.89476, 'test_loss': 0.8040643614530564, 'test_acc': 0.7686} 




100%|██████████| 500/500 [00:22<00:00, 22.14it/s]


{'epoch': 16, 'train_loss': 0.265937347099185, 'train_acc': 0.905, 'test_loss': 0.8295180919766426, 'test_acc': 0.7749} 




100%|██████████| 500/500 [00:22<00:00, 22.09it/s]


{'epoch': 17, 'train_loss': 0.23797395367175342, 'train_acc': 0.91296, 'test_loss': 0.8677560076117515, 'test_acc': 0.7619} 




100%|██████████| 500/500 [00:22<00:00, 22.02it/s]


{'epoch': 18, 'train_loss': 0.21702010571956634, 'train_acc': 0.92212, 'test_loss': 0.9502158200740815, 'test_acc': 0.7579} 




100%|██████████| 500/500 [00:22<00:00, 21.99it/s]


{'epoch': 19, 'train_loss': 0.2110704362541437, 'train_acc': 0.92424, 'test_loss': 0.9888644459843635, 'test_acc': 0.758} 




100%|██████████| 500/500 [00:22<00:00, 21.95it/s]


{'epoch': 20, 'train_loss': 0.1929065525829792, 'train_acc': 0.93042, 'test_loss': 0.9706305411458015, 'test_acc': 0.7621} 




100%|██████████| 500/500 [00:22<00:00, 22.02it/s]


{'epoch': 21, 'train_loss': 0.17716267386823892, 'train_acc': 0.93614, 'test_loss': 0.9875084179639816, 'test_acc': 0.7647} 




100%|██████████| 500/500 [00:22<00:00, 22.07it/s]


{'epoch': 22, 'train_loss': 0.16579891206324102, 'train_acc': 0.93932, 'test_loss': 1.0589040327072143, 'test_acc': 0.7662} 




100%|██████████| 500/500 [00:22<00:00, 22.09it/s]


{'epoch': 23, 'train_loss': 0.16213388059288264, 'train_acc': 0.94078, 'test_loss': 1.1077380716800689, 'test_acc': 0.7601} 




100%|██████████| 500/500 [00:22<00:00, 22.12it/s]


{'epoch': 24, 'train_loss': 0.14554912316799165, 'train_acc': 0.94636, 'test_loss': 1.1399243557453156, 'test_acc': 0.7541} 




100%|██████████| 500/500 [00:22<00:00, 22.10it/s]


{'epoch': 25, 'train_loss': 0.14932419420406223, 'train_acc': 0.9464, 'test_loss': 1.0811153423786164, 'test_acc': 0.7613} 




100%|██████████| 500/500 [00:22<00:00, 22.01it/s]


{'epoch': 26, 'train_loss': 0.12932989673316478, 'train_acc': 0.95362, 'test_loss': 1.1050321638584137, 'test_acc': 0.7695} 




100%|██████████| 500/500 [00:22<00:00, 22.07it/s]


{'epoch': 27, 'train_loss': 0.13256687539443374, 'train_acc': 0.95276, 'test_loss': 1.0784189608693122, 'test_acc': 0.7631} 




100%|██████████| 500/500 [00:22<00:00, 22.12it/s]


{'epoch': 28, 'train_loss': 0.13103545020520688, 'train_acc': 0.9517, 'test_loss': 1.0561434170603752, 'test_acc': 0.7675} 




100%|██████████| 500/500 [00:22<00:00, 21.99it/s]


{'epoch': 29, 'train_loss': 0.10744178697094321, 'train_acc': 0.96212, 'test_loss': 1.1781392571330072, 'test_acc': 0.7588} 




100%|██████████| 500/500 [00:22<00:00, 21.92it/s]


{'epoch': 30, 'train_loss': 0.11353918751887977, 'train_acc': 0.95906, 'test_loss': 1.2160244572162628, 'test_acc': 0.762} 




100%|██████████| 500/500 [00:22<00:00, 21.87it/s]


{'epoch': 31, 'train_loss': 0.11027528399042785, 'train_acc': 0.96124, 'test_loss': 1.2141996771097183, 'test_acc': 0.7498} 




100%|██████████| 500/500 [00:22<00:00, 22.03it/s]


{'epoch': 32, 'train_loss': 0.11112531591765583, 'train_acc': 0.96074, 'test_loss': 1.2462581476569177, 'test_acc': 0.7612} 




100%|██████████| 500/500 [00:22<00:00, 22.09it/s]


{'epoch': 33, 'train_loss': 0.09530816930904984, 'train_acc': 0.96564, 'test_loss': 1.2581138491630555, 'test_acc': 0.7576} 




100%|██████████| 500/500 [00:22<00:00, 22.04it/s]


{'epoch': 34, 'train_loss': 0.09941168914735317, 'train_acc': 0.96382, 'test_loss': 1.2483202636241912, 'test_acc': 0.7578} 




100%|██████████| 500/500 [00:22<00:00, 21.95it/s]


{'epoch': 35, 'train_loss': 0.10353179510310292, 'train_acc': 0.9628, 'test_loss': 1.2505984136462212, 'test_acc': 0.7538} 




100%|██████████| 500/500 [00:22<00:00, 21.96it/s]


{'epoch': 36, 'train_loss': 0.09574861910752952, 'train_acc': 0.9666, 'test_loss': 1.2014009547233582, 'test_acc': 0.7622} 




100%|██████████| 500/500 [00:22<00:00, 21.94it/s]


{'epoch': 37, 'train_loss': 0.09237931445799769, 'train_acc': 0.96702, 'test_loss': 1.2663970679044723, 'test_acc': 0.759} 




100%|██████████| 500/500 [00:22<00:00, 22.02it/s]


{'epoch': 38, 'train_loss': 0.09163041462004184, 'train_acc': 0.96738, 'test_loss': 1.2419514772295952, 'test_acc': 0.76} 




100%|██████████| 500/500 [00:22<00:00, 22.10it/s]


{'epoch': 39, 'train_loss': 0.0887056564129889, 'train_acc': 0.96872, 'test_loss': 1.2916370475292205, 'test_acc': 0.7501} 




100%|██████████| 500/500 [00:22<00:00, 22.16it/s]


{'epoch': 40, 'train_loss': 0.08139373925514519, 'train_acc': 0.97164, 'test_loss': 1.3141795018315314, 'test_acc': 0.7619} 




100%|██████████| 500/500 [00:22<00:00, 22.07it/s]


{'epoch': 41, 'train_loss': 0.08636011333204806, 'train_acc': 0.96966, 'test_loss': 1.3764675283432006, 'test_acc': 0.7592} 




100%|██████████| 500/500 [00:22<00:00, 22.13it/s]


{'epoch': 42, 'train_loss': 0.08225116002559661, 'train_acc': 0.97074, 'test_loss': 1.1986731886863708, 'test_acc': 0.7629} 




100%|██████████| 500/500 [00:22<00:00, 22.09it/s]


{'epoch': 43, 'train_loss': 0.08302330714091659, 'train_acc': 0.9713, 'test_loss': 1.2355385828018188, 'test_acc': 0.7679} 




100%|██████████| 500/500 [00:22<00:00, 22.04it/s]


{'epoch': 44, 'train_loss': 0.07349813577719033, 'train_acc': 0.97444, 'test_loss': 1.301476318538189, 'test_acc': 0.7716} 




100%|██████████| 500/500 [00:22<00:00, 22.01it/s]


{'epoch': 45, 'train_loss': 0.07975638332311064, 'train_acc': 0.97126, 'test_loss': 1.3427917063236237, 'test_acc': 0.7661} 




100%|██████████| 500/500 [00:22<00:00, 22.16it/s]


{'epoch': 46, 'train_loss': 0.07776935051381588, 'train_acc': 0.9728, 'test_loss': 1.35319864153862, 'test_acc': 0.7514} 




100%|██████████| 500/500 [00:22<00:00, 22.06it/s]


{'epoch': 47, 'train_loss': 0.07596011842414736, 'train_acc': 0.97386, 'test_loss': 1.2606508812308312, 'test_acc': 0.7735} 




100%|██████████| 500/500 [00:22<00:00, 22.01it/s]


{'epoch': 48, 'train_loss': 0.06985008442495018, 'train_acc': 0.97536, 'test_loss': 1.4041248989105224, 'test_acc': 0.7626} 




100%|██████████| 500/500 [00:22<00:00, 21.96it/s]


{'epoch': 49, 'train_loss': 0.06270915338210761, 'train_acc': 0.9776, 'test_loss': 1.433794322013855, 'test_acc': 0.7506} 




100%|██████████| 500/500 [00:23<00:00, 20.92it/s]


{'epoch': 50, 'train_loss': 0.07326164066977799, 'train_acc': 0.974, 'test_loss': 1.3633489269018173, 'test_acc': 0.7557} 




In [17]:
import plotly.express as px

history_df = pd.read_csv("./history_resnet.csv")

fig = px.line(history_df, x="epoch", y=["train_loss", "test_loss"])
fig.show()

fig = px.line(history_df, x="epoch", y=["train_acc", "test_acc"])
fig.show()