In [4]:
import numpy as np
from tqdm import tqdm
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.datasets import ImageFolder
from torcheval.metrics.functional import binary_confusion_matrix 

# Seed the RNGs used below (fc-layer init, DataLoader shuffling) so
# Restart & Run All reproduces the same numbers.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
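# Classification metrics computed from a 2x2 binary confusion matrix `cm`
# laid out as [[tn, fp], [fn, tp]] (rows = true label, columns = predicted label):
#   tp = cm[1][1]
#   fp = cm[0][1]
#   fn = cm[1][0]
#   tn = cm[0][0]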

def precision(cm: np.ndarray) -> float:
    """Precision tp / (tp + fp) from a 2x2 confusion matrix [[tn, fp], [fn, tp]].

    Returns 0.0 when the model predicted no positives (tp + fp == 0) instead of
    producing nan / ZeroDivisionError.
    """
    tp = cm[1][1]
    fp = cm[0][1]
    predicted_pos = tp + fp
    if predicted_pos == 0:
        return 0.0
    return float(tp / predicted_pos)

def recall(cm: np.ndarray) -> float:
    """Recall tp / (tp + fn) from a 2x2 confusion matrix [[tn, fp], [fn, tp]].

    Returns 0.0 when there are no actual positives (tp + fn == 0) instead of
    producing nan / ZeroDivisionError.
    """
    tp = cm[1][1]
    fn = cm[1][0]
    actual_pos = tp + fn
    if actual_pos == 0:
        return 0.0
    return float(tp / actual_pos)

def accuracy(cm: np.ndarray) -> float:
    """Fraction of correctly classified samples: (tn + tp) / total count in ``cm``.

    ``cm`` is a 2x2 confusion matrix laid out [[tn, fp], [fn, tp]].
    """
    n_correct = cm[0][0] + cm[1][1]
    n_total = cm.sum()
    return n_correct / n_total

def f_score(cm: np.ndarray, factor: float = 1) -> float:
    """F-beta score from a 2x2 confusion matrix [[tn, fp], [fn, tp]].

    F_beta = (1 + beta^2) * P * R / (beta^2 * P + R), computed in the
    algebraically equivalent count form
        (1 + beta^2) * tp / ((1 + beta^2) * tp + beta^2 * fn + fp)
    which avoids the intermediate 0/0 when precision or recall is undefined.

    Args:
        cm: confusion matrix with rows = true label, columns = predicted label.
        factor: beta; 1 gives F1, >1 weights recall more, <1 weights precision more.

    Returns:
        The F-beta score as a float; 0.0 when tp, fp and fn are all zero.
    """
    tp = cm[1][1]
    fp = cm[0][1]
    fn = cm[1][0]
    beta_sq = factor ** 2
    numerator = (1 + beta_sq) * tp
    denominator = numerator + beta_sq * fn + fp
    if denominator == 0:
        return 0.0
    return float(numerator / denominator)

# torch workflow
def create_dataloader(dataset: torch.utils.data.Dataset,
                      batch_size: int=16,
                      pin_memory: bool=True,
                      shuffle: bool=True,
                      num_workers: int=0) -> torch.utils.data.DataLoader:
    """Wrap ``dataset`` in a ``DataLoader``.

    Args:
        dataset: map-style dataset to iterate (e.g. an ``ImageFolder``).
        batch_size: samples per batch.
        pin_memory: request page-locked host memory; only honoured when CUDA
            is available, since it only speeds up host->GPU copies.
        shuffle: reshuffle the data every epoch (use False for evaluation).
        num_workers: subprocesses for loading; 0 loads in the main process
            (the ``DataLoader`` default).

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        # Without a GPU, pinned memory brings no benefit and only emits a warning.
        pin_memory=pin_memory and torch.cuda.is_available(),
        num_workers=num_workers,
    )

def train_step(model: nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: nn.Module,
              optimizer: torch.optim.Optimizer,
              device: str,
              validation: bool=False) -> "tuple[float, np.ndarray]":
    """Run one pass over ``dataloader`` for a single-logit binary classifier.

    With ``validation=False`` the model is put in train mode and ``optimizer``
    is stepped after every batch. With ``validation=True`` the model is put in
    eval mode, autograd is disabled via ``torch.inference_mode`` and
    ``optimizer`` is not used.

    Args:
        model: network producing one logit per sample, shape (batch, 1) or (batch,).
        dataloader: yields ``(X, y)`` batches; ``y`` must hold 0/1 labels.
        loss_fn: loss taking ``(logits, float_targets)``, e.g. ``nn.BCEWithLogitsLoss``.
        optimizer: optimizer over the trainable parameters (ignored when validating).
        device: device the batches are moved to; should match the model's device.
        validation: if True, evaluate only -- no parameter updates.

    Returns:
        ``(mean_loss, cm)``: ``mean_loss`` is the average per-batch loss
        (``nan`` for an empty loader); ``cm`` is a 2x2 int64 numpy array
        accumulated over all batches, laid out ``[[tn, fp], [fn, tp]]``
        (rows = true label, columns = predicted label).
    """
    model.train(not validation)  # train(False) is equivalent to eval()

    total_loss = 0.0
    n_batches = 0
    cm = torch.zeros(2, 2, dtype=torch.long, device=device)

    # inference_mode(False) is a no-op, so the training pass keeps autograd.
    with torch.inference_mode(validation):
        for X, y in dataloader:
            X = X.to(device)
            y = y.float().to(device)

            # BCEWithLogitsLoss needs logits and targets of the same shape:
            # flatten (batch, 1) -> (batch,). reshape(-1) (unlike squeeze())
            # keeps a batch of size 1 one-dimensional.
            logits = model(X).reshape(-1)
            loss_batch = loss_fn(logits, y)

            if not validation:
                optimizer.zero_grad()
                loss_batch.backward()
                optimizer.step()

            total_loss += loss_batch.item()
            n_batches += 1

            # sigmoid(z) > 0.5  <=>  z > 0
            preds = (logits.detach() > 0).long()
            # Encode each (true, pred) pair as 2*true + pred; counting the codes
            # fills the flattened matrix [tn, fp, fn, tp]. Accumulate across
            # batches so the matrix covers the whole pass, not just the last batch.
            cm += torch.bincount(y.long() * 2 + preds, minlength=4).reshape(2, 2)

    mean_loss = total_loss / n_batches if n_batches else float('nan')
    return mean_loss, cm.cpu().numpy()

def freeze_model(model):
    """Disable gradient tracking for every parameter of ``model``, in place.

    After this, backward() no longer populates ``.grad`` for those parameters.
    Returns None.
    """
    model.requires_grad_(False)


In [6]:
# Pretrained ResNet-18 weight descriptor and its matching preprocessing
# (nothing is downloaded here; the checkpoint is fetched by resnet18(weights=...)).
weights = ResNet18_Weights.DEFAULT
transforms_auto = ResNet18_Weights.DEFAULT.transforms()

In [None]:
path_train = '../png_roi/train_images'
path_test = '../png_roi/test_images'

dataset_train = ImageFolder(path_train, transform=transforms_auto)
dataset_test = ImageFolder(path_test, transform=transforms_auto)

# The model has a single output logit trained with BCE, so the labels must be
# 0/1: exactly two class folders, mapped identically in both splits (a differing
# folder set would silently flip or shift the test labels).
assert len(dataset_train.classes) == 2, f'expected 2 classes, found {dataset_train.classes}'
assert dataset_train.class_to_idx == dataset_test.class_to_idx, 'train/test class mapping differs'

loader_train = create_dataloader(dataset_train)
loader_test = create_dataloader(dataset_test, shuffle=False)

In [None]:
model = resnet18(weights=weights)
freeze_model(model)
# Replace the classifier with a fresh (trainable) single-logit head; read the
# input width from the existing head instead of hard-coding 512.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=1, bias=True)
model.to(device)

lr = 3e-4
# The backbone is frozen, so only the new head has trainable parameters.
# NOTE(review): plain SGD (no momentum) at lr=3e-4 may converge slowly -- confirm intended.
optimizer = torch.optim.SGD(model.fc.parameters(), lr=lr)
loss_fn = nn.BCEWithLogitsLoss()

In [None]:
n_epochs = 1
log_every = 5  # report losses/metrics every `log_every` epochs and after the last one

for epoch in range(n_epochs):

    print(f'Epoch #{epoch+1}')

    train_loss, train_cm = train_step(model, loader_train, loss_fn, optimizer, device)
    # validation=True: eval mode and no optimizer steps -- never train on the test split.
    test_loss, test_cm = train_step(model, loader_test, loss_fn, optimizer, device, validation=True)

    if epoch % log_every == 0 or epoch == n_epochs - 1:
        print(f'Train loss: [{train_loss}] Test loss: [{test_loss}]')
        print(f'Test accuracy: [{accuracy(test_cm):.3f}] '
              f'precision: [{precision(test_cm):.3f}] '
              f'recall: [{recall(test_cm):.3f}] '
              f'F1: [{f_score(test_cm):.3f}]')