# Imports

In [1]:
import os
import time

import timm
import torch
import torch.nn as nn
from PIL import Image
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms

# Function Definitions

In [2]:
def _mean(total, count):
    """Return ``total / count`` as a Python float, or NaN when nothing was counted."""
    # float() also converts the 0-dim (possibly on-GPU) accumulator tensors.
    return float(total) / count if count else float('nan')


def _train_one_epoch(epoch, model, criterion, optimizer, loader, device,
                     save_dir, save_every, print_every, iters_per_epoch):
    """Run one optimisation pass over ``loader``; return (mean loss, accuracy) over the samples seen.

    Prints the batch loss every ``print_every`` iterations and writes an intermediate
    checkpoint every ``save_every`` iterations (both counted within this epoch).
    """
    model.train()  # re-enable train-mode layers (dropout etc.) after the previous epoch's eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    for i, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)

        logits = model(inputs)
        loss = criterion(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Keep the accumulators as detached tensors so no host sync is forced every step.
        # Weighting by batch size assumes criterion returns a per-batch mean
        # (the CrossEntropyLoss default), so the epoch mean is sample-weighted.
        batch_size = targets.size(0)
        loss_sum += loss.detach() * batch_size
        correct += (logits.argmax(dim=1) == targets).sum()
        seen += batch_size

        if (i + 1) % print_every == 0:
            print(f'Iteration {i + 1}/{iters_per_epoch} complete; Loss: {loss.item()}')

        # checkpoint saving based on iterations
        if (i + 1) % save_every == 0:
            checkpoint_path = os.path.join(save_dir, f'epoch-{epoch}_iter-{i + 1}_checkpoint.pth')
            torch.save({
                'epoch': epoch,
                'iteration': i + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item()
            }, checkpoint_path)
            print(f'Checkpoint saved at {checkpoint_path} after {i + 1}/{iters_per_epoch} iterations of epoch {epoch}')

    return _mean(loss_sum, seen), _mean(correct, seen)


def _evaluate(model, criterion, loader, device):
    """Evaluate ``model`` on every batch of ``loader`` without gradients; return (mean loss, accuracy)."""
    model.eval()  # inference behaviour for dropout / normalisation layers
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)

            # Accumulate over ALL batches; same mean-reduction assumption as in training.
            batch_size = targets.size(0)
            loss_sum += criterion(logits, targets) * batch_size
            correct += (logits.argmax(dim=1) == targets).sum()
            seen += batch_size

    return _mean(loss_sum, seen), _mean(correct, seen)


def train_model(epochs, model, criterion, optimizer, train_loader, valid_loader, device, save_dir, save_every, print_every):
    """Train ``model`` with per-epoch validation, periodic checkpoints and best-model saving.

    Each epoch is one pass over ``train_loader`` in train mode followed by one pass
    over ``valid_loader`` in eval mode without gradients. Whenever validation
    accuracy improves, ``best_model.pth`` in ``save_dir`` is overwritten.

    Args:
        epochs: number of passes over ``train_loader``.
        model: module whose outputs are treated as class logits (prediction = argmax over dim 1).
            It must already live on ``device``; it is left in eval mode on return.
        criterion: loss called as ``criterion(logits, targets)``; assumed to use mean reduction.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        valid_loader: iterable of ``(inputs, targets)`` validation batches.
        device: device every batch is moved to.
        save_dir: existing directory that receives the checkpoint files.
        save_every: write an intermediate checkpoint every ``save_every`` iterations of an epoch.
        print_every: print the current batch loss every ``print_every`` iterations.

    Returns:
        dict with per-epoch lists ``hist_train_loss``, ``hist_train_accs``,
        ``hist_valid_loss`` and ``hist_valid_accs``. Losses are sample-weighted
        epoch means; accuracies are fractions in [0, 1] (NaN for an empty loader).
    """
    history = {
        'hist_train_loss': [],
        'hist_train_accs': [],
        'hist_valid_loss': [],
        'hist_valid_accs': [],
    }
    best_accuracy = 0.0
    iters_per_epoch = len(train_loader)

    for epoch in range(epochs):
        train_loss, train_accuracy = _train_one_epoch(
            epoch, model, criterion, optimizer, train_loader, device,
            save_dir, save_every, print_every, iters_per_epoch,
        )
        valid_loss, valid_accuracy = _evaluate(model, criterion, valid_loader, device)

        history['hist_train_loss'].append(train_loss)
        history['hist_train_accs'].append(train_accuracy)
        history['hist_valid_loss'].append(valid_loss)
        history['hist_valid_accs'].append(valid_accuracy)

        print(
            f'[epoch: {epoch}]\n'
            f'- train loss: {train_loss}, train accuracy: {train_accuracy}\n'
            f'- valid loss: {valid_loss}, valid accuracy: {valid_accuracy}'
        )

        # NaN never compares greater, so an empty validation set never triggers a save.
        if valid_accuracy > best_accuracy:
            best_accuracy = valid_accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                **history,
            }, os.path.join(save_dir, 'best_model.pth'))
            # valid_accuracy is a fraction; the '%' format spec multiplies by 100.
            print(f'- New best model saved with accuracy {valid_accuracy:.2%}')

    return history

# Main Execution

In [3]:
SEED = 42

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# torch.manual_seed seeds the CPU generator AND every CUDA device. Seeding only CUDA
# would leave the CPU generator unseeded; that generator drives the model's weight
# init (timm builds the model on CPU before .to(device)) and DataLoader shuffling.
torch.manual_seed(SEED)

print(f'Using device: {device}')

Using device: cuda


In [4]:
# PREPARING THE DATASET

# Full BBBC022 screen location on the cluster (alternative to the local copy below):
# data_path = '/scratch/ad5497/data/ftp.ebi.ac.uk/pub/databases/IDR/idr0016-wawer-bioactivecompoundprofiling/2016-01-19-screens-bbbc022'
data_path = '../data/pretraining-data'

# Resize every image to 224x224 (the input size of vit_base_patch16_224 used below) and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# ImageFolder derives the class labels from the sub-directory names of data_path.
dataset = datasets.ImageFolder(root=data_path, transform=transform)

# The train/valid/test split is loaded from disk rather than re-sampled here,
# so every run that reads this file uses the same partition.
dataset_indices_path = '../data/pretraining-dataset-indices.pth'
indices = torch.load(dataset_indices_path)

# Overlapping splits would leak training images into validation/test and inflate accuracy.
split_ids = {name: set(map(int, indices[name])) for name in ('train_indices', 'valid_indices', 'test_indices')}
assert split_ids['train_indices'].isdisjoint(split_ids['valid_indices']), 'train/valid splits overlap'
assert split_ids['train_indices'].isdisjoint(split_ids['test_indices']), 'train/test splits overlap'
assert split_ids['valid_indices'].isdisjoint(split_ids['test_indices']), 'valid/test splits overlap'

train_dataset = Subset(dataset, indices['train_indices'])
valid_dataset = Subset(dataset, indices['valid_indices'])
test_dataset = Subset(dataset, indices['test_indices'])

bs = 64
num_workers = 8
# Page-locked host memory speeds up host-to-GPU copies; it brings nothing on CPU.
pin_memory = device.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

In [None]:
# CREATING AND TRAINING THE ViT MODEL

# timm's vit_base_patch16_224, randomly initialised (pretrained=False), with one output per ImageFolder class.
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=len(dataset.classes))
model = model.to(device)

epochs = 10
learning_rate = 1e-4

optimizer = Adam(model.parameters(), lr=learning_rate)
criterion = CrossEntropyLoss()

save_dir = 'models'
os.makedirs(save_dir, exist_ok=True)
save_every = 90   # training iterations between intermediate checkpoints
print_every = 30  # training iterations between loss printouts

# perf_counter is a monotonic clock meant for measuring durations; time.time() can jump
# if the wall clock is adjusted during a multi-hour run.
initial = time.perf_counter()
train_model(epochs, model, criterion, optimizer, train_loader, valid_loader, device, save_dir, save_every, print_every)
final = time.perf_counter()

print(f'Total time for {epochs} epochs of {len(train_loader)} iterations: {final - initial:.4f} seconds')

  return F.conv2d(input, weight, bias, self.stride,


Iteration 30/3780 complete; Loss: 1.5932883024215698
Iteration 60/3780 complete; Loss: 1.5925735235214233
Iteration 90/3780 complete; Loss: 1.4878363609313965
Checkpoint saved at models/epoch-0_iter-90_checkpoint.pth after 90/3780 iterations of epoch 0
Iteration 120/3780 complete; Loss: 0.9722626209259033
Iteration 150/3780 complete; Loss: 1.0215811729431152
Iteration 180/3780 complete; Loss: 0.9535061120986938
Checkpoint saved at models/epoch-0_iter-180_checkpoint.pth after 180/3780 iterations of epoch 0
Iteration 210/3780 complete; Loss: 0.8120726346969604
Iteration 240/3780 complete; Loss: 1.5471727848052979
Iteration 270/3780 complete; Loss: 0.9804590344429016
Checkpoint saved at models/epoch-0_iter-270_checkpoint.pth after 270/3780 iterations of epoch 0
Iteration 300/3780 complete; Loss: 1.0825762748718262
Iteration 330/3780 complete; Loss: 0.9063102602958679
Iteration 360/3780 complete; Loss: 0.8838392496109009
Checkpoint saved at models/epoch-0_iter-360_checkpoint.pth after 360/