In [22]:
from datasets import load_dataset
import cv2
import albumentations
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.transforms import v2
from tqdm import tqdm

import os

In [23]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [24]:
ds = load_dataset("mrm8488/ImageNet1K-val", split="train", num_proc=4)

ds

Repo card metadata block was not found. Setting CardData to empty.


Dataset({
    features: ['image', 'label'],
    num_rows: 50000
})

In [None]:
train_transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.Resize(size=(142, 142)),
        v2.RandomCrop(size=(128, 128)),
        v2.RandomHorizontalFlip(0.5),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

test_transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.Resize(size=(128, 128)),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

show_transforms = v2.Resize(size=(128, 128))


def train_transform(dataset):
    dataset["image"] = [train_transforms(image) for image in dataset["image"]]
    return dataset


def test_transform(dataset):
    dataset["image"] = [test_transforms(image) for image in dataset["image"]]
    return dataset


train_test = ds.train_test_split(test_size=0.1, seed=42)
train_val = train_test["train"].train_test_split(test_size=0.2, seed=42)

train_data = train_val["train"].map(train_transform, batched=True)
val_data = train_val["test"].map(test_transform, batched=True)
test_data = train_test["test"].map(test_transform, batched=True)

Map: 100%|██████████| 36000/36000 [02:54<00:00, 206.11 examples/s]
Map: 100%|██████████| 9000/9000 [01:05<00:00, 136.94 examples/s]
Map: 100%|██████████| 5000/5000 [00:48<00:00, 103.77 examples/s]


In [38]:
train_data.set_format("torch")
val_data.set_format("torch")
test_data.set_format("torch")

In [None]:
batch_size = 64

train_loader = DataLoader(
    train_data, batch_size=batch_size, num_workers=2, shuffle=True
)
val_loader = DataLoader(val_data, batch_size=batch_size, num_workers=2, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, num_workers=2, shuffle=False)

In [47]:
class CNN(nn.Module):
    def __init__(self, num_classes=1000):
        super(CNN, self).__init__()

        # Feature Extractor
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=1),  # 62x62x64
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 30x30x64
            nn.Conv2d(64, 192, kernel_size=5, padding=2),  # 30x30x192
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 14x14x192
            nn.Conv2d(192, 384, kernel_size=3, padding=1),  # 14x14x384
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # 14x14x256
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # 14x14x256
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 6x6x256
        )

        # Classifier
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),  # FC1
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),  # FC2
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),  # FC3
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [48]:
cnn = CNN()
cnn.to(device)

optimizer = optim.Adam(cnn.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [51]:
num_epochs = 800

for epoch in range(num_epochs):
    cnn.train()
    running_loss = 0.0

    for data in tqdm(train_loader):
        images, labels = data.values()
        images, labels = images.to("cuda"), labels.to("cuda")

        optimizer.zero_grad()
        outputs = cnn(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(
        f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}"
    )

    cnn.eval()
    total, correct = 0, 0
    running_val_loss = 0.0

    with torch.no_grad():
        for data in tqdm(val_loader):
            images, labels = data.values()
            images, labels = images.to("cuda"), labels.to("cuda")

            outputs = cnn(images)
            val_loss = criterion(outputs, labels)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            running_val_loss += val_loss.item()

    accuracy = 100 * correct / total
    print(
        f"Accuracy: {accuracy:.2f}% Validation Loss: {running_val_loss / len(val_loader):.4f}"
    )

    if epoch % 50 == 0:
        torch.save(
            {
                "epoch": epoch,
                "model": cnn.state_dict(),
                "optimizer": optimizer.state_dict(),
                "loss": loss,
                "val_loss": val_loss,
                "accuracy": accuracy,
            },
            f"epoch_{epoch}.pth",
        )

        print(f"Epoch #{epoch} saved")

100%|██████████| 563/563 [00:35<00:00, 16.07it/s]


Epoch [1/800], Loss: 6.9062


100%|██████████| 141/141 [00:10<00:00, 13.73it/s]


Accuracy: 0.04% Validation Loss: 6.9228
Epoch #0 saved


100%|██████████| 563/563 [00:36<00:00, 15.61it/s]


Epoch [2/800], Loss: 6.9060


100%|██████████| 141/141 [00:10<00:00, 13.68it/s]


Accuracy: 0.04% Validation Loss: 6.9240


100%|██████████| 563/563 [00:35<00:00, 16.03it/s]


Epoch [3/800], Loss: 6.9059


100%|██████████| 141/141 [00:10<00:00, 13.56it/s]


Accuracy: 0.04% Validation Loss: 6.9252


100%|██████████| 563/563 [00:34<00:00, 16.13it/s]


Epoch [4/800], Loss: 6.9058


100%|██████████| 141/141 [00:10<00:00, 13.78it/s]


Accuracy: 0.04% Validation Loss: 6.9261


100%|██████████| 563/563 [00:34<00:00, 16.14it/s]


Epoch [5/800], Loss: 6.9058


100%|██████████| 141/141 [00:10<00:00, 13.47it/s]


Accuracy: 0.04% Validation Loss: 6.9270


100%|██████████| 563/563 [00:34<00:00, 16.16it/s]


Epoch [6/800], Loss: 6.9057


100%|██████████| 141/141 [00:10<00:00, 13.74it/s]


Accuracy: 0.04% Validation Loss: 6.9278


100%|██████████| 563/563 [00:34<00:00, 16.12it/s]


Epoch [7/800], Loss: 6.9057


100%|██████████| 141/141 [00:10<00:00, 13.59it/s]


Accuracy: 0.04% Validation Loss: 6.9284


100%|██████████| 563/563 [00:36<00:00, 15.35it/s]


Epoch [8/800], Loss: 6.9056


100%|██████████| 141/141 [00:10<00:00, 13.13it/s]


Accuracy: 0.04% Validation Loss: 6.9290


100%|██████████| 563/563 [00:45<00:00, 12.50it/s]


Epoch [9/800], Loss: 6.9056


100%|██████████| 141/141 [00:16<00:00,  8.80it/s]


Accuracy: 0.04% Validation Loss: 6.9296


100%|██████████| 563/563 [00:52<00:00, 10.64it/s]


Epoch [10/800], Loss: 6.9056


100%|██████████| 141/141 [00:10<00:00, 13.53it/s]


Accuracy: 0.04% Validation Loss: 6.9300


100%|██████████| 563/563 [00:46<00:00, 12.12it/s]


Epoch [11/800], Loss: 6.9056


100%|██████████| 141/141 [00:11<00:00, 12.03it/s]


Accuracy: 0.04% Validation Loss: 6.9304


100%|██████████| 563/563 [00:49<00:00, 11.38it/s]


Epoch [12/800], Loss: 6.9056


100%|██████████| 141/141 [00:10<00:00, 13.58it/s]


Accuracy: 0.04% Validation Loss: 6.9307


100%|██████████| 563/563 [00:35<00:00, 15.98it/s]


Epoch [13/800], Loss: 6.9055


100%|██████████| 141/141 [00:10<00:00, 13.87it/s]


Accuracy: 0.04% Validation Loss: 6.9310


100%|██████████| 563/563 [00:34<00:00, 16.28it/s]


Epoch [14/800], Loss: 6.9056


100%|██████████| 141/141 [00:10<00:00, 13.83it/s]


Accuracy: 0.04% Validation Loss: 6.9313


 58%|█████▊    | 329/563 [00:21<00:15, 15.41it/s]


KeyboardInterrupt: 

In [None]:
loaded_checkpoint = torch.load("epoch_150.pth")
loaded_checkpoint

{'epoch': 150,
 'model': OrderedDict([('features.0.weight',
               tensor([[[[ 0.0638, -0.0079, -0.0126,  ...,  0.0037, -0.0312,  0.0663],
                         [ 0.0237,  0.0293, -0.0241,  ..., -0.0408,  0.0032, -0.0330],
                         [-0.0088,  0.0635,  0.0543,  ...,  0.0300,  0.0123,  0.0317],
                         ...,
                         [ 0.0058, -0.0663, -0.0533,  ..., -0.0347, -0.0385, -0.0591],
                         [-0.0441,  0.0703,  0.0091,  ...,  0.0714, -0.0022, -0.0239],
                         [-0.0693, -0.0021, -0.0688,  ...,  0.0445, -0.0057,  0.0206]],
               
                        [[ 0.0757,  0.0073,  0.0857,  ...,  0.0100, -0.0254, -0.0197],
                         [ 0.0709, -0.0133, -0.0502,  ..., -0.0007,  0.0292,  0.0344],
                         [-0.0375, -0.0295, -0.0616,  ..., -0.0034, -0.0106,  0.0724],
                         ...,
                         [-0.0455, -0.0110,  0.0553,  ..., -0.0502, -0.0664,  0.