In [None]:
# Mount Google Drive at /content/drive (Colab-only) so files stored there can be read.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Extract the dataset archive from the mounted Drive into /content.
!unzip /content/drive/MyDrive/cell_data.zip -d /content

In [None]:
from glob import glob
# glob returns matches in arbitrary (filesystem-dependent) order. Sorting makes
# the list, and the index-based file names derived from it later, reproducible.
images = sorted(glob("/content/cell_data/*"))
print(len(images))

10365


In [None]:
import os

# Move every extracted image into a per-class sub-directory so that
# torchvision's ImageFolder can infer the label from the directory name.
for i, image in enumerate(images):
  # Skip entries that are not regular files: class directories created by a
  # previous run, or stale paths whose file has already been moved. This keeps
  # the cell safe to re-run.
  if not os.path.isfile(image):
    continue
  # The cell line is the part of the *file name* before the first underscore.
  # Parsing the base name (rather than the full path) keeps this correct even
  # if a parent directory name contains underscores.
  cell_line = os.path.basename(image).split("_")[0]
  class_dir = f"/content/cell_data/{cell_line}"
  os.makedirs(class_dir, exist_ok=True)
  os.replace(image, f"{class_dir}/{i}.png")

In [None]:
import numpy as np
import torch
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.transforms import v2

num_classes = 6

# Fixed seed so the train/validation/test membership is the same on every run.
SPLIT_SEED = 42

# Training pipeline: deterministic preprocessing plus random augmentation.
train_transform = v2.Compose([
    v2.Resize(size=(400, 400)),
    v2.RandomRotation((-180, 180)),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToTensor(),
    v2.Normalize((0.5,), (0.5,))
])

# Evaluation pipeline: identical preprocessing but NO random augmentation, so
# validation/test metrics are deterministic and measured on unaltered images.
eval_transform = v2.Compose([
    v2.Resize(size=(400, 400)),
    v2.ToTensor(),
    v2.Normalize((0.5,), (0.5,))
])

# Two views over the same directory: same samples and indices, different transforms.
dataset = ImageFolder(root="/content/cell_data", transform=train_transform)
eval_dataset = ImageFolder(root="/content/cell_data", transform=eval_transform)

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, validation_split, test_split = torch.utils.data.random_split(
    dataset, [0.93, 0.05, 0.02], generator=split_generator)

# Re-point the held-out indices at the augmentation-free view of the data.
validation_dataset = torch.utils.data.Subset(eval_dataset, validation_split.indices)
test_dataset = torch.utils.data.Subset(eval_dataset, test_split.indices)

print(f"""
Train dataset: {len(train_dataset)}
Validation dataset: {len(validation_dataset)}
Test dataset: {len(test_dataset)}
""")


Train dataset: 9640
Validation dataset: 518
Test dataset: 207





In [None]:
import matplotlib.pyplot as plt
import torchvision.models as models
from torch import nn, optim
from torch.nn import BCEWithLogitsLoss
from datetime import datetime

# Load ResNet-50 with the ImageNet weights that `pretrained=True` selected; the
# `pretrained` flag is deprecated in favour of the explicit `weights` argument.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Freeze the parameters of the pre-trained layers
for param in model.parameters():
    param.requires_grad = False

# Unfreeze the parameters of the last few layers for fine-tuning
for param in model.layer4.parameters():
    param.requires_grad = True

# Replace the classification head AFTER freezing. A freshly constructed
# nn.Linear has requires_grad=True, so the new head is actually trained.
# (Creating it before the freeze loop left it frozen at its random init.)
model.fc = nn.Linear(model.fc.in_features, num_classes)

In [None]:
# Install the Weights & Biases client used for experiment logging below.
%pip install wandb

In [None]:
import wandb

# Training hyperparameters; the same values are recorded in the W&B run config.
EPOCHS = 20
BATCH_SIZE = 64
LEARNING_RATE = 0.001
MOMENTUM = 0.9

# Authenticate with Weights & Biases, then open the run used for metric logging.
wandb.login()

wandb.init(
  project="cell-classification",
  name="resnet-experiment-1",
  config=dict(
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    optimizer="SGD",
    learning_rate=LEARNING_RATE,
    momentum=MOMENTUM,
    architecture="ResNet50",
  ),
)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mskareerik55[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
model.to(device)

# Cross-entropy objective optimised with SGD + momentum.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

# Only the training split is shuffled; validation and test keep a fixed order.
training_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

def calculate_accuracy(outputs, labels):
  """Return the fraction of rows whose highest-scoring class equals the label.

  The predicted class is the index of the maximum along dimension 1 of
  ``outputs``. The result is the tensor obtained by dividing the number of
  matches by the number of predictions.
  """
  predicted = torch.max(outputs, 1).indices
  hits = predicted.eq(labels).sum()
  return hits / predicted.shape[0]

def train_one_epoch(epoch_index):
    """Run one optimisation pass over ``training_loader``.

    Relies on the notebook-level ``model``, ``optimizer``, ``loss_fn``,
    ``device`` and ``training_loader``. The caller is responsible for putting
    the model into training mode beforehand.

    Args:
        epoch_index: Index of the current epoch. Not used by the body; kept so
            existing call sites continue to work.

    Returns:
        A ``(mean_loss, mean_accuracy)`` tuple of Python floats, averaged per
        sample over the whole epoch.

    Raises:
        ValueError: If ``training_loader`` yields no samples.
    """
    loss_sum = 0.
    acc_sum = 0.
    num_samples = 0

    for data in training_loader:
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = model(inputs)

        loss = loss_fn(outputs, labels)
        loss.backward()

        optimizer.step()

        # Weight each batch by its size so a smaller final batch does not count
        # as much as a full one. float()/item() also move the metrics off the
        # device instead of accumulating tensors.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        acc_sum += float(calculate_accuracy(outputs, labels)) * batch_size
        num_samples += batch_size

    if num_samples == 0:
        # Previously this surfaced as an unbound loop variable in the return.
        raise ValueError("training_loader produced no samples")

    return loss_sum / num_samples, acc_sum / num_samples

In [None]:
# Timestamp shared by every checkpoint written during this training run.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# Lowest validation loss seen so far; any finite loss improves on infinity.
best_vloss = float('inf')

for epoch_number in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    model.train(True)
    avg_loss, avg_acc = train_one_epoch(epoch_number)
    # train_one_epoch may hand back 0-dim tensors; log and print plain floats.
    avg_loss, avg_acc = float(avg_loss), float(avg_acc)

    # Evaluation mode for the validation pass.
    model.eval()

    vloss_sum = 0.
    vacc_sum = 0.
    num_vsamples = 0
    with torch.no_grad():
        for vdata in validation_loader:
            vinputs, vlabels = vdata[0].to(device), vdata[1].to(device)
            voutputs = model(vinputs)
            # Accumulate Python floats weighted by batch size so the smaller
            # final batch does not skew the epoch average.
            batch_size = vlabels.size(0)
            vloss_sum += loss_fn(voutputs, vlabels).item() * batch_size
            vacc_sum += float(calculate_accuracy(voutputs, vlabels)) * batch_size
            num_vsamples += batch_size

    if num_vsamples == 0:
        # Previously this surfaced as an unbound loop variable.
        raise ValueError("validation_loader produced no samples")

    avg_vloss = vloss_sum / num_vsamples
    avg_vacc = vacc_sum / num_vsamples

    wandb.log({
        "loss": avg_loss,
        "val_loss": avg_vloss,
        "acc": avg_acc,
        "val_acc": avg_vacc,
    })
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    print('ACC train {} valid {}'.format(avg_acc, avg_vacc))

    # Checkpoint whenever the validation loss improves.
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

EPOCH 1:
LOSS train 1.371936714017628 valid 0.8402350544929504
ACC train 0.4396316409111023 valid 0.7274305820465088
EPOCH 2:
LOSS train 0.5625608658159016 valid 0.35298892855644226
ACC train 0.8442052602767944 valid 0.9184027910232544
EPOCH 3:
LOSS train 0.2687852250424442 valid 0.1801905781030655
ACC train 0.9480960369110107 valid 0.9704861044883728
EPOCH 4:
LOSS train 0.15547217117832196 valid 0.10575270652770996
ACC train 0.9756209254264832 valid 0.9809027910232544
EPOCH 5:
LOSS train 0.10010102544201921 valid 0.0674423947930336
ACC train 0.9839611053466797 valid 0.9895833134651184
EPOCH 6:
LOSS train 0.06931579855597572 valid 0.047333527356386185
ACC train 0.9908940196037292 valid 0.9895833134651184
EPOCH 7:
LOSS train 0.05212571339940788 valid 0.03528901934623718
ACC train 0.9953435659408569 valid 1.0
EPOCH 8:


KeyboardInterrupt: 

In [None]:
# End the W&B run that was opened with wandb.init.
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
acc,▁▆▇████
loss,█▄▂▂▁▁▁
val_acc,▁▆▇████
val_loss,█▄▂▂▁▁▁

0,1
acc,0.99534
loss,0.05213
val_acc,1.0
val_loss,0.03529


In [None]:
def check_accuracy(loader, model, device=None):
    """Print and return the classification accuracy of ``model`` on ``loader``.

    Args:
        loader: Iterable yielding ``(inputs, labels)`` batches.
        model: Module mapping a batch of inputs to per-class scores, with the
            class scores along dimension 1.
        device: Device the batches are moved to before the forward pass.
            Defaults to the device of the model's first parameter (CPU for a
            model without parameters).

    Returns:
        The fraction of correctly classified samples as a Python float.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    num_correct = 0
    num_samples = 0
    # Remember the current mode so evaluation does not silently flip the model
    # into training mode for whoever called us.
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device=device)
                y = y.to(device=device)

                scores = model(x)
                _, predictions = scores.max(1)
                # .item() keeps the running count a plain int, not a tensor.
                num_correct += (predictions == y).sum().item()
                num_samples += predictions.size(0)
    finally:
        model.train(was_training)

    if num_samples == 0:
        # Previously this surfaced as a ZeroDivisionError inside the f-string.
        raise ValueError("loader produced no samples")

    accuracy = num_correct / num_samples
    print(f'Got {num_correct} / {num_samples} with accuracy {accuracy*100:.2f}%')
    return accuracy

# Report the model's accuracy on the test split.
check_accuracy(test_loader, model)

Got 207 / 207 with accuracy 100.00%
