In [1]:
from utils.split import RandomSplitPt
from utils.dataset import HemorrhageDataset
from utils.model import HemoResNet50, HemoResNet18
from utils.metric import hemorrhage_metrics
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch
import numpy as np
import pandas as pd

In [2]:
# calculate weight
train_df = pd.read_csv("./Blood_data/train.csv")
num_each_class = train_df[["ich", "ivh", "sah", "sdh", "edh"]].sum().values
total_num = num_each_class.sum()
weight = num_each_class / total_num
print(weight)

[0.22552461 0.15764975 0.22293018 0.3169401  0.07695536]


In [3]:
random_split_pt = RandomSplitPt()
train, val = random_split_pt.randomly_split()

In [4]:
train_dataset = HemorrhageDataset(train, stack_img = True, mode="train")
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers = 16, pin_memory=True)

val_dataset = HemorrhageDataset(val, stack_img = True, mode="val")
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers = 16, pin_memory=True)

In [5]:
device = "cuda"
num_class = 5
max_epochs = 50
pos_weight = torch.Tensor(weight).to(device)

model = HemoResNet18(num_class)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model.to(device)

HemoResNet18(
  (base_model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, t

In [None]:
best_val_f2 = 0

for epoch in range(1, max_epochs + 1):
    model.train()
    train_pred = []
    train_true = []
    train_loss_accum = 0
    for i, (data, label) in enumerate(train_loader, 1):
        print(f"Process {i} / {len(train_loader)}    ", end="\r")
        data = data.to(device)
        label = label.to(device).float()
        
        logits = model(data)
        loss = F.binary_cross_entropy_with_logits(logits, label, pos_weight=pos_weight)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss_accum += loss.item()
        train_pred.append(logits.cpu().detach().numpy())
        train_true.append(label.cpu().numpy())
    
    model.eval()
    val_pred = []
    val_true = []
    val_loss_accum = 0
    with torch.no_grad():
        for i, (data, label) in enumerate(val_loader, 1):
            print(f"Process {i} / {len(val_loader)}    ", end="\r")
            data = data.to(device)
            label = label.to(device).float()
            
            logits = model(data)
            loss = F.binary_cross_entropy_with_logits(logits, label, pos_weight= pos_weight)
            
            val_loss_accum += loss.item()
            val_pred.append(logits.cpu().detach().numpy())
            val_true.append(label.cpu().numpy())
    
    train_pred = np.concatenate((train_pred))
    train_true = np.concatenate((train_true))
    val_pred = np.concatenate((val_pred))
    val_true = np.concatenate((val_true))
    
    train_metric = hemorrhage_metrics(train_pred, train_true)
    val_metric = hemorrhage_metrics(val_pred, val_true)
    
    if val_metric['f2'] > best_val_f2:
        best_val_f2 = val_metric['f2']
        print(F"BEST AT epoch {epoch:3d} || VAL F2 = {val_metric['f2']:.4f}")
        torch.save(model.state_dict(), "./model/resnet18_basic.pth")

    print(f"[epoch {epoch:3d}] TRAIN acc {train_metric['acc']:.4f} f2 {train_metric['f2']:.4f} loss {train_loss_accum:.4f} || VAL acc {val_metric['acc']:.4f} f2 {val_metric['f2']:.4f} loss {val_loss_accum:.4f}")

[epoch   1] TRAIN acc 0.8981 f2 0.042709 loss 51.1904 || VAL acc 0.8981 f2 0.024192 loss 16.2777
[epoch   2] TRAIN acc 0.9089 f2 0.160673 loss 42.4626 || VAL acc 0.9124 f2 0.219067 loss 15.9440
Process 25 / 593    