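# ResNet-50 phase recognition: training, validation and test

Fine-tunes an ImageNet-pretrained ResNet-50 to classify frames of RALIHR hernia surgery videos into 7 merged surgical phases. Videos 0001–0070 are split into training (65 videos) and validation (5 videos); videos 0071–0120 form the held-out test set. The saved checkpoint reaches about 0.77 accuracy and 0.78 weighted F1 on the test set (see Evaluation).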

In [1]:
import torch
from torchvision.transforms import (
    Compose, ToTensor, Resize, Normalize, RandomResizedCrop, RandomHorizontalFlip, RandomVerticalFlip
)
from torchvision.models import resnet50

from data import HerniaDataset, PhaseMapper
from utils import ResnetTrainer, ResnetEvaluator

import random

%load_ext autoreload
%autoreload 2

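## Dataset

The first 70 videos are shuffled and split 65 / 5 into training and validation. The last 50 videos (0071–0120) keep their order and form the test set.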

In [2]:
root = '../../surgery_hernia_train_test/'

# Video ids are 1-based and zero-padded to four digits: ..._0001 ... ..._0120.
VIDEO_NAME_FMT = 'RALIHR_surgeon01_fps01_{:04}'
NUM_TRAIN_VALID_VIDEOS = 70  # videos 0001-0070 are shuffled, then sliced into train/valid
NUM_VIDEOS = 120             # videos 0071-0120 keep their order and are the test set
SPLIT_SEED = 0               # fixes the train/valid partition so re-runs are reproducible

videos = [VIDEO_NAME_FMT.format(i + 1) for i in range(NUM_TRAIN_VALID_VIDEOS)]
# A private, seeded RNG makes the split reproducible without reseeding the
# global ``random`` state. Only the first 70 videos are shuffled, so the test
# videos are the same as before seeding. Checkpoints trained earlier used an
# unrecorded train/valid partition, though.
random.Random(SPLIT_SEED).shuffle(videos)
videos += [VIDEO_NAME_FMT.format(i + 1) for i in range(NUM_TRAIN_VALID_VIDEOS, NUM_VIDEOS)]

In [3]:
# Label table mapping annotations onto the merged phase set (7 classes, per the file name).
MERGED_LABELS_CSV = '../configs/all_labels_hernia_merged_7.csv'
mapper = PhaseMapper(MERGED_LABELS_CSV)
mapper.get_merged_labels()

Unnamed: 0,labels
0,mesh placement
1,out of body
2,peritoneal closure
3,peritoneal scoring
4,preperioneal dissection
5,reduction of hernia
6,transitionary idle


In [4]:
input_shape = (224, 224)
# Per-channel RGB normalization statistics.
# NOTE(review): presumably computed on the training frames; confirm provenance.
mean = [0.41757566, 0.26098573, 0.25888634]
std = [0.21938758, 0.1983, 0.19342837]

# NOTE(review): training takes a near-square crop covering 80-100% of the frame
# area (aspect ratio 0.9-1.1). Validation/test instead resize the *whole* frame
# to 224x224. If the source frames are not square, the two pipelines see
# different framing and aspect distortion. Confirm the frame size and consider
# aligning them before the next training run.
train_transform = Compose([
    RandomResizedCrop(size=input_shape, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    ToTensor(),
    RandomHorizontalFlip(),
    Normalize(mean, std),
])
valid_transform = Compose([
    Resize(input_shape),
    ToTensor(),
    Normalize(mean, std),
])

# Split boundaries into ``videos``: [0, 65) train, [65, 70) validation, [70, 120) test.
TRAIN_END, VALID_END, TEST_END = 65, 70, 120
train_videos = videos[:TRAIN_END]
valid_videos = videos[TRAIN_END:VALID_END]
test_videos = videos[VALID_END:TEST_END]

# Guard against an empty split, e.g. if ``videos`` is shorter than expected.
assert train_videos and valid_videos and test_videos, 'empty video split'
# Guard against leakage: every video must appear in exactly one split.
all_split_videos = train_videos + valid_videos + test_videos
assert len(set(all_split_videos)) == len(all_split_videos), 'video splits overlap'

train_set = HerniaDataset(root, train_videos, transforms=train_transform, class_map=mapper)
valid_set = HerniaDataset(root, valid_videos, transforms=valid_transform, class_map=mapper)
test_set = HerniaDataset(root, test_videos, transforms=valid_transform, class_map=mapper)
len(train_set), len(valid_set), len(test_set)

(201991, 13066, 147205)

## Model

In [5]:
# Prefer the first GPU when CUDA is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

In [6]:
# ImageNet-pretrained ResNet-50 backbone with a freshly initialised classifier head.
# NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in favour
# of ``weights=ResNet50_Weights.IMAGENET1K_V1``. It is kept here so older
# torchvision versions still work.
model = resnet50(pretrained=True)
# Size the head from the backbone's feature width and the merged label table,
# so it stays correct if either changes (currently 2048 -> 7).
num_classes = len(mapper.get_merged_labels()['labels'])
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

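## Training

**Caveat on the log below:**

- It comes from an interrupted run that stopped during epoch 5.
- It reports the run name `resnet50-p7-v120-b64-lr1em5-sa`, although the cell passes `batch_size=16`. The batch counts (817 validation batches for 13,066 frames, 12,625 training batches for 201,991 frames) are consistent with a batch size of 16, so the printed name, not the batch size, is stale.
- Validation loss rises after epoch 2 (0.81 → 0.91 → 1.18) while training loss keeps falling, which suggests overfitting.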

In [7]:
# Wrap the model and target device in the project's training helper (utils.ResnetTrainer).
trainer = ResnetTrainer(model, device)

In [None]:
def make_run_name(batch_size, learning_rate, prefix='resnet50-p7-v120', suffix='sa'):
    """Build the run name from the hyperparameters it encodes.

    Deriving the name, rather than hand-typing it, keeps the ``b<batch>`` and
    ``lr<rate>`` tags in sync with the values actually passed to training.
    Example: ``make_run_name(16, 1e-5) == 'resnet50-p7-v120-b16-lr1em5-sa'``.

    Args:
        batch_size: training batch size, encoded as ``b<batch_size>``.
        learning_rate: optimizer learning rate, encoded as ``lr<mantissa>e<exp>``
            with ``-`` written as ``m`` (``1e-5`` -> ``lr1em5``).
        prefix: leading name component (model / label-set / video-set tags).
        suffix: trailing name component.

    Returns:
        The dash-joined run name.
    """
    mantissa, exponent = f'{learning_rate:e}'.split('e')
    # '1.000000' -> '1', '2.500000' -> '2.5'; fall back to '0' for a zero rate.
    mantissa = mantissa.rstrip('0').rstrip('.') or '0'
    lr_tag = f'{mantissa}e{int(exponent)}'.replace('e-', 'em')
    return f'{prefix}-b{batch_size}-lr{lr_tag}-{suffix}'


BATCH_SIZE = 16
LEARNING_RATE = 1e-5
run_name = make_run_name(BATCH_SIZE, LEARNING_RATE)  # 'resnet50-p7-v120-b16-lr1em5-sa'

hist = trainer.train(
    train_set, valid_set,
    num_epochs=10,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    run_name=run_name,
    num_workers=4,
    prefetch_factor=2,
)
hist

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mharrypotter1501[0m ([33meezklab[0m). Use [1m`wandb login --relogin`[0m to force relogin


Running resnet50-p7-v120-b64-lr1em5-sa
Datasets: num_train = 201991, num_validation = 13066
Main metric: valid_f1_weighted
Epoch 1/10: validating 99.88% (batch 817/817)5)



Epoch 1/10: train_loss 0.30350582, train_accuracy 0.89640628, train_f1_macro 0.87818865, train_f1_weighted 0.89478125, valid_loss 0.84388769, valid_accuracy 0.72715445, valid_f1_macro 0.64753303, valid_f1_weighted 0.71676099, _timestamp 1659045747.00000000, _runtime 2336.00000000
Epoch 2/10: validating 99.88% (batch 817/817)5)



Epoch 2/10: train_loss 0.12030736, train_accuracy 0.96002792, train_f1_macro 0.95161936, train_f1_weighted 0.95984517, valid_loss 0.81340564, valid_accuracy 0.74705342, valid_f1_macro 0.72433417, valid_f1_weighted 0.75496539, _timestamp 1659048056.00000000, _runtime 4645.00000000
Epoch 3/10: train_loss 0.07395071, train_accuracy 0.97571179, train_f1_macro 0.96908541, train_f1_weighted 0.97565492, valid_loss 0.90524502, valid_accuracy 0.74636461, valid_f1_macro 0.70294618, valid_f1_weighted 0.74277215, _timestamp 1659050358.00000000, _runtime 6947.00000000
Epoch 4/10: train_loss 0.05138595, train_accuracy 0.98334084, train_f1_macro 0.97794335, train_f1_weighted 0.98332672, valid_loss 1.17999471, valid_accuracy 0.73434869, valid_f1_macro 0.66885208, valid_f1_weighted 0.72437877, _timestamp 1659052665.00000000, _runtime 9254.00000000
Epoch 5/10: training 95.81% (batch 12097/12625)

## Evaluation

Load the saved checkpoint and evaluate it on the 50 held-out test videos (0071–0120).

In [10]:
# Restore the fine-tuned weights. The file is presumably written by
# ResnetTrainer under the run name; confirm against utils.
# NOTE(review): the training log above printed run name '...-b64-...' although
# batch_size=16. Confirm this file comes from the intended run.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
CHECKPOINT_PATH = './model/resnet50-p7-v120-b16-lr1em5-sa.pt'
# ``map_location`` remaps stored tensors onto ``device``, so a checkpoint saved
# on a GPU also loads in a CPU-only kernel.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

<All keys matched successfully>

In [11]:
# Phase names from the merged label table, used as row labels in the
# per-class report. NOTE(review): presumably row i names class index i.
class_names = mapper.get_merged_labels()['labels']
evaluator = ResnetEvaluator(model, device, class_names)

In [12]:
# Evaluate on every test frame. ``res`` holds the summary metrics; ``report``
# is the per-class text report printed in the next cell.
eval_kwargs = dict(num_workers=4, prefetch_factor=32)
res, report = evaluator.evaluate(test_set, **eval_kwargs)
res

Testing ResNet
Datasets: num_test = 147205
Testing 100.00% (batch 147201/147205)

{'time': 0.014633959485826874,
 'accuracy': 0.7708297951835875,
 'f1_macro': 0.7659636026471139,
 'f1_weighted': 0.7750889642501818}

In [13]:
# ``report`` is a plain-text per-class precision/recall/F1 table, so print it
# rather than displaying its escaped repr.
print(report)

                         precision    recall  f1-score   support

         mesh placement       0.82      0.72      0.76     34092
            out of body       0.93      0.98      0.96       906
     peritoneal closure       0.88      0.90      0.89     36029
     peritoneal scoring       0.82      0.65      0.73      6745
preperioneal dissection       0.70      0.62      0.66     18409
    reduction of hernia       0.83      0.82      0.82     36725
     transitionary idle       0.45      0.67      0.54     14299

               accuracy                           0.77    147205
              macro avg       0.78      0.77      0.77    147205
           weighted avg       0.79      0.77      0.78    147205

