In [None]:
# %% Init

import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils
import torch.utils.data
import torchvision
from torchvision import transforms as T
from tqdm import tqdm
import datasets

torch.set_printoptions(sci_mode=False, linewidth=120)


def setup_seed(seed=42):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)


setup_seed(42)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [None]:
# %% Define datasets and loaders

def get_datasets():
    setup_seed(42)

    image_size = 28
    transform = T.Compose([T.Resize(image_size), T.ToTensor(), T.Normalize(mean=[0.5], std=[0.5])])

    cifar10 = datasets.load_dataset("cifar10")
    train_eval_split = cifar10["train"].train_test_split(test_size=0.1, stratify_by_column="label")
    cifar10 = datasets.DatasetDict(
        {"train": train_eval_split["train"], "eval": train_eval_split["test"], "test": cifar10["test"]}
    )
    cifar10 = cifar10.rename_column("img", "image")
    cifar10 = cifar10.cast_column("image", datasets.Image(mode="L"))
    cifar10 = cifar10.map(lambda sample: {"pixel_values": transform(sample["image"])})
    cifar10.set_format("pt", columns=["pixel_values"], output_all_columns=True)

    mnist = datasets.load_dataset("mnist")
    train_eval_split = mnist["train"].train_test_split(test_size=0.1, stratify_by_column="label")
    mnist = datasets.DatasetDict(
        {"train": train_eval_split["train"], "eval": train_eval_split["test"], "test": mnist["test"]}
    )
    mnist = mnist.map(lambda sample: {"pixel_values": transform(sample["image"])})
    mnist.set_format("pt", columns=["pixel_values"], output_all_columns=True)

    return cifar10, mnist


def collate_fn(examples):
    images = []
    labels = []
    for example in examples:
        images.append(example["pixel_values"])
        labels.append(example["label"])

    pixel_values = torch.stack(images)
    labels = torch.tensor(labels)
    return {"pixel_values": pixel_values, "labels": labels}


def get_data_loader(dataset):
    batch_size = 32

    loader = {}
    for split, data in dataset.items():
        loader[split] = torch.utils.data.DataLoader(data, collate_fn=collate_fn, batch_size=batch_size, pin_memory=True)

    return loader

In [None]:
# %% Load datasets

cifar10, mnist = get_datasets()

cifar10_loaders = get_data_loader(cifar10)
mnist_loaders = get_data_loader(mnist)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [None]:
# %% Train/eval loops

def train_one_epoch(model, criterion, optimizer, train_loader, device):
    model.train()

    train_loss = 0
    train_accuracy = 0

    for batch in tqdm(train_loader):
        inputs = batch["pixel_values"]
        labels = batch["labels"]
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        y_pred_class = outputs.argmax(dim=1)
        batch_accuracy = (y_pred_class == labels).sum().item() / len(y_pred_class)
        train_accuracy += batch_accuracy

    train_loss = train_loss / len(train_loader)
    train_accuracy = train_accuracy / len(train_loader)

    return train_loss, train_accuracy


def validate(model, criterion, val_loader, device):
    model.eval()

    test_loss = 0.0
    test_accuracy = 0

    with torch.inference_mode():
        for batch in tqdm(val_loader):
            inputs = batch["pixel_values"]
            labels = batch["labels"]
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            y_pred_class = outputs.argmax(dim=1)
            test_accuracy += (y_pred_class == labels).sum().item() / len(y_pred_class)

    test_loss = test_loss / len(val_loader)
    test_accuracy = test_accuracy / len(val_loader)

    return test_loss, test_accuracy


def train_model(model, loaders: dict):
    model.to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    num_epochs = 5

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        train_loss, train_accuracy = train_one_epoch(model, criterion, optimizer, loaders["train"], device)

        val_loss, val_accuracy = validate(model, criterion, loaders["eval"], device)

        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

    return model

In [None]:
# %% Train the base model (on cifar10)

model_base_cifar = torchvision.models.resnet18(num_classes=10)
# Make the model take images with 1 channel
model_base_cifar.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

model_base_cifar = train_model(model_base_cifar, cifar10_loaders)

Epoch 1/5


100%|██████████| 1407/1407 [07:30<00:00,  3.12it/s]
100%|██████████| 157/157 [00:10<00:00, 14.66it/s]


Train Loss: 1.5575, Train Accuracy: 0.4458
Val Loss: 1.3916, Val Accuracy: 0.5165
Epoch 2/5


100%|██████████| 1407/1407 [07:04<00:00,  3.32it/s]
100%|██████████| 157/157 [00:10<00:00, 14.72it/s]


Train Loss: 1.1894, Train Accuracy: 0.5830
Val Loss: 1.2501, Val Accuracy: 0.5657
Epoch 3/5


100%|██████████| 1407/1407 [07:05<00:00,  3.31it/s]
100%|██████████| 157/157 [00:14<00:00, 10.82it/s]


Train Loss: 0.9948, Train Accuracy: 0.6516
Val Loss: 1.0830, Val Accuracy: 0.6320
Epoch 4/5


100%|██████████| 1407/1407 [08:21<00:00,  2.81it/s]
100%|██████████| 157/157 [00:12<00:00, 12.87it/s]


Train Loss: 0.8328, Train Accuracy: 0.7118
Val Loss: 1.0859, Val Accuracy: 0.6445
Epoch 5/5


100%|██████████| 1407/1407 [08:11<00:00,  2.86it/s]
100%|██████████| 157/157 [00:10<00:00, 14.56it/s]

Train Loss: 0.6947, Train Accuracy: 0.7595
Val Loss: 1.1716, Val Accuracy: 0.6399





In [None]:
# %% Train the base model (on mnist)

model_base_mnist = torchvision.models.resnet18(num_classes=10)
model_base_mnist.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

model_base_mnist = train_model(model_base_mnist, mnist_loaders)

Epoch 1/5


100%|██████████| 1688/1688 [09:25<00:00,  2.98it/s]
100%|██████████| 188/188 [00:12<00:00, 15.31it/s]


Train Loss: 0.1585, Train Accuracy: 0.9533
Val Loss: 0.0596, Val Accuracy: 0.9819
Epoch 2/5


100%|██████████| 1688/1688 [08:55<00:00,  3.15it/s]
100%|██████████| 188/188 [00:11<00:00, 16.14it/s]


Train Loss: 0.0696, Train Accuracy: 0.9802
Val Loss: 0.0629, Val Accuracy: 0.9820
Epoch 3/5


100%|██████████| 1688/1688 [08:39<00:00,  3.25it/s]
100%|██████████| 188/188 [00:11<00:00, 16.73it/s]


Train Loss: 0.0561, Train Accuracy: 0.9843
Val Loss: 0.0682, Val Accuracy: 0.9822
Epoch 4/5


100%|██████████| 1688/1688 [08:35<00:00,  3.27it/s]
100%|██████████| 188/188 [00:11<00:00, 16.59it/s]


Train Loss: 0.0420, Train Accuracy: 0.9873
Val Loss: 0.0865, Val Accuracy: 0.9764
Epoch 5/5


100%|██████████| 1688/1688 [09:54<00:00,  2.84it/s]
100%|██████████| 188/188 [00:15<00:00, 12.11it/s]

Train Loss: 0.0339, Train Accuracy: 0.9895
Val Loss: 0.0411, Val Accuracy: 0.9904





In [None]:
# %% Fine-tune the model on new dataset (mnist)

import copy

model_cifar_to_mnist = copy.deepcopy(model_base_cifar)

# Prepare model for fine tuning

for param in model_cifar_to_mnist.parameters():
    param.requires_grad = False

model_cifar_to_mnist.fc = torch.nn.Linear(in_features=512, out_features=10, bias=True)

model_base_mnist = train_model(model_base_mnist, mnist_loaders)

Epoch 1/5


100%|██████████| 1688/1688 [10:44<00:00,  2.62it/s]
100%|██████████| 188/188 [00:14<00:00, 13.39it/s]


Train Loss: 0.0315, Train Accuracy: 0.9909
Val Loss: 0.0497, Val Accuracy: 0.9882
Epoch 2/5


 63%|██████▎   | 1069/1688 [06:45<03:54,  2.63it/s]


KeyboardInterrupt: 

In [None]:
from torchvision.models import ResNet18_Weights

model = torchvision.models.resnet18(weights="IMAGENET1K_V1")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /Users/citizen2/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:04<00:00, 10.8MB/s]


In [None]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
model.fc

Linear(in_features=512, out_features=1000, bias=True)

In [None]:
model.fc.params()

AttributeError: 'Linear' object has no attribute 'params'

In [None]:
model.fc.params

AttributeError: 'Linear' object has no attribute 'params'

In [None]:
model.fc

Linear(in_features=512, out_features=1000, bias=True)

In [None]:
model.fc.parameters()

<generator object Module.parameters at 0x178355af0>

In [None]:
list(model.fc.parameters())

[Parameter containing:
 tensor([[-0.0185, -0.0705, -0.0518,  ..., -0.0390,  0.1735, -0.0410],
         [-0.0818, -0.0944,  0.0174,  ...,  0.2028, -0.0248,  0.0372],
         [-0.0332, -0.0566, -0.0242,  ..., -0.0344, -0.0227,  0.0197],
         ...,
         [-0.0103,  0.0033, -0.0359,  ..., -0.0279, -0.0115,  0.0128],
         [-0.0359, -0.0353, -0.0296,  ..., -0.0330, -0.0110, -0.0513],
         [ 0.0021, -0.0248, -0.0829,  ...,  0.0417, -0.0500,  0.0663]], requires_grad=True),
 Parameter containing:
 tensor([    -0.0026,      0.0030,      0.0007,     -0.0269,      0.0064,      0.0133,     -0.0112,      0.0206,
             -0.0036,     -0.0123,     -0.0126,     -0.0072,     -0.0193,     -0.0250,     -0.0119,     -0.0083,
             -0.0096,     -0.0167,      0.0092,     -0.0154,      0.0071,      0.0307,      0.0132,     -0.0078,
              0.0047,      0.0112,      0.0159,     -0.0167,     -0.0010,     -0.0037,      0.0065,     -0.0120,
              0.0090,     -0.0008,      

In [None]:
model.fc.params

AttributeError: 'Linear' object has no attribute 'params'

In [None]:
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")

params = list(model.fc.parameters())
params


[Parameter containing:
 tensor([[-0.0185, -0.0705, -0.0518,  ..., -0.0390,  0.1735, -0.0410],
         [-0.0818, -0.0944,  0.0174,  ...,  0.2028, -0.0248,  0.0372],
         [-0.0332, -0.0566, -0.0242,  ..., -0.0344, -0.0227,  0.0197],
         ...,
         [-0.0103,  0.0033, -0.0359,  ..., -0.0279, -0.0115,  0.0128],
         [-0.0359, -0.0353, -0.0296,  ..., -0.0330, -0.0110, -0.0513],
         [ 0.0021, -0.0248, -0.0829,  ...,  0.0417, -0.0500,  0.0663]], requires_grad=True),
 Parameter containing:
 tensor([    -0.0026,      0.0030,      0.0007,     -0.0269,      0.0064,      0.0133,     -0.0112,      0.0206,
             -0.0036,     -0.0123,     -0.0126,     -0.0072,     -0.0193,     -0.0250,     -0.0119,     -0.0083,
             -0.0096,     -0.0167,      0.0092,     -0.0154,      0.0071,      0.0307,      0.0132,     -0.0078,
              0.0047,      0.0112,      0.0159,     -0.0167,     -0.0010,     -0.0037,      0.0065,     -0.0120,
              0.0090,     -0.0008,      

In [None]:
params[0]

Parameter containing:
tensor([[-0.0185, -0.0705, -0.0518,  ..., -0.0390,  0.1735, -0.0410],
        [-0.0818, -0.0944,  0.0174,  ...,  0.2028, -0.0248,  0.0372],
        [-0.0332, -0.0566, -0.0242,  ..., -0.0344, -0.0227,  0.0197],
        ...,
        [-0.0103,  0.0033, -0.0359,  ..., -0.0279, -0.0115,  0.0128],
        [-0.0359, -0.0353, -0.0296,  ..., -0.0330, -0.0110, -0.0513],
        [ 0.0021, -0.0248, -0.0829,  ...,  0.0417, -0.0500,  0.0663]], requires_grad=True)

In [None]:
params[0].shape

torch.Size([1000, 512])

In [None]:
params[1].shape

torch.Size([1000])

In [None]:
params[2].shape

IndexError: list index out of range

In [None]:
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")

params = list(model.fc.parameters())
params[0]


Parameter containing:
tensor([[-0.0185, -0.0705, -0.0518,  ..., -0.0390,  0.1735, -0.0410],
        [-0.0818, -0.0944,  0.0174,  ...,  0.2028, -0.0248,  0.0372],
        [-0.0332, -0.0566, -0.0242,  ..., -0.0344, -0.0227,  0.0197],
        ...,
        [-0.0103,  0.0033, -0.0359,  ..., -0.0279, -0.0115,  0.0128],
        [-0.0359, -0.0353, -0.0296,  ..., -0.0330, -0.0110, -0.0513],
        [ 0.0021, -0.0248, -0.0829,  ...,  0.0417, -0.0500,  0.0663]], requires_grad=True)

In [None]:
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")

params = list(model.fc.parameters())
print(params[0])
params[0].shape

Parameter containing:
tensor([[-0.0185, -0.0705, -0.0518,  ..., -0.0390,  0.1735, -0.0410],
        [-0.0818, -0.0944,  0.0174,  ...,  0.2028, -0.0248,  0.0372],
        [-0.0332, -0.0566, -0.0242,  ..., -0.0344, -0.0227,  0.0197],
        ...,
        [-0.0103,  0.0033, -0.0359,  ..., -0.0279, -0.0115,  0.0128],
        [-0.0359, -0.0353, -0.0296,  ..., -0.0330, -0.0110, -0.0513],
        [ 0.0021, -0.0248, -0.0829,  ...,  0.0417, -0.0500,  0.0663]], requires_grad=True)


torch.Size([1000, 512])

In [None]:
model = torchvision.models.resnet18(weights="IMAGENET1K_V1", num_classes=10)

params = list(model.fc.parameters())
print(params[0])
params[0].shape

ValueError: The parameter 'num_classes' expected value 1000 but got 10 instead.

In [None]:
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
model.fc

Linear(in_features=512, out_features=1000, bias=True)

In [None]:
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
model.fc = Linear(in_features=512, out_features=10, bias=True)

NameError: name 'Linear' is not defined

In [None]:
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
model.fc = torch.nn.Linear(in_features=512, out_features=10, bias=True)

In [None]:
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
model.fc = torch.nn.Linear(in_features=512, out_features=10, bias=True)

In [None]:
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
model.fc = torch.nn.Linear(in_features=512, out_features=10, bias=True)

In [None]:
model.fc.parameters()

<generator object Module.parameters at 0x178864ac0>

In [None]:
list(model.fc.parameters())

[Parameter containing:
 tensor([[ 0.0171, -0.0432,  0.0253,  ..., -0.0237, -0.0168, -0.0085],
         [ 0.0425, -0.0218,  0.0109,  ...,  0.0075, -0.0185, -0.0160],
         [-0.0260,  0.0099, -0.0143,  ...,  0.0331,  0.0127, -0.0039],
         ...,
         [-0.0023,  0.0315,  0.0395,  ...,  0.0030, -0.0348, -0.0373],
         [ 0.0030,  0.0270,  0.0397,  ...,  0.0191,  0.0062, -0.0079],
         [ 0.0080,  0.0401,  0.0224,  ...,  0.0193, -0.0369,  0.0291]], requires_grad=True),
 Parameter containing:
 tensor([ 0.0080,  0.0376,  0.0237,  0.0279,  0.0177, -0.0375, -0.0344,  0.0234,  0.0192,  0.0065], requires_grad=True)]

In [None]:
criterion = torch.nn.CrossEntropyLoss()

validate(model, criterion, cifar10_loaders["eval"], "cpu")

  0%|          | 0/157 [00:00<?, ?it/s]


RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[32, 1, 28, 28] to have 3 channels, but got 1 channels instead