In [1]:
# test-time augmentation
# ensemble
# threshold
import timm
import sys
import os
import pandas as pd
import torch
import matplotlib.pyplot as plt
from datetime import datetime
import pickle
import torchvision
from sklearn.metrics import f1_score
import torch.optim as optim
from tqdm import tqdm
from time import time, sleep

from torch.utils.data import DataLoader
from dataset import EmbryoDataset
from models import CustomVit
from augmentations import Cutout
from utils import SAM,LR_Scheduler
from utils import RandAugment,PadAndResize

In [2]:
# Enumerate GPUs in PCI bus order and expose only physical GPU 1 to this
# process, where it is then addressed as cuda:0. CUDA reads these variables
# when it initialises, so they must be set before any CUDA work happens.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": '1',
})

In [None]:
batch_size = 16
img_size = 224
rho = 0.05  # passed to SAM as `rho`
cuda_device_index = 0  # index among the GPUs left visible by CUDA_VISIBLE_DEVICES
learning_rate = 0.001
momentum = 0.9
warmup_epochs = 3  # NOTE(review): not referenced by the training cells of this notebook
weight_decay = 0.005
# NOTE: the optimizer cell reassigns `epochs` (to 1000) before the LR schedule
# and the progress bar are built, so this value is not the one used for training.
epochs = 100
num_workers = 8  # workers for dataloader
# ImageNet channel statistics, matching the ImageNet-1k pretrained backbone below.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

root = "/data/train"

# Pickled fold splits for the two members of the fold-3 ensemble.
train_file_0 = "train_data_fold_3_0.pkl"
validation_file_0 = "validation_data_fold_3_0.pkl"
train_file_1 = "train_data_fold_3_1.pkl"
validation_file_1 = "validation_data_fold_3_1.pkl"

n = 5  # for randaugment: first argument of RandAugment(n, m) (presumably number of ops)
m = 12  # for randaugment: second argument of RandAugment(n, m) (presumably magnitude)
checkpoint_dir = "Models/ensemble_fold_3"
# exist_ok=True replaces the separate isdir() check and avoids its check-then-create race.
os.makedirs(checkpoint_dir, exist_ok=True)
device = torch.device("cuda:" + str(cuda_device_index) if torch.cuda.is_available() else "cpu")

model = timm.create_model("hf_hub:timm/maxvit_small_tf_224.in1k", pretrained=True, num_classes=2)
# NOTE: the training loop reuses the name `labels` for each batch's targets,
# which overwrites this list.
labels = [0, 1]


In [None]:
# Training augmentations, built directly in their final order: RandAugment(n, m)
# first (before any resizing), then pad-and-resize to img_size, random flips,
# tensor conversion + normalisation, and finally Cutout on the normalised tensor.
transforms_train = torchvision.transforms.Compose([
    RandAugment(n, m),
    PadAndResize((img_size, img_size)),
    torchvision.transforms.RandomVerticalFlip(p=0.3),
    torchvision.transforms.RandomHorizontalFlip(p=0.3),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean, std),
    Cutout(n_holes=5, length=32, p=0.3),
])

# Validation: no random augmentation, only pad/resize and normalisation.
transforms_validation = torchvision.transforms.Compose([
    PadAndResize((img_size, img_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean, std),
])


def _load_split(path):
    """Return the object pickled at ``path``; it is passed to EmbryoDataset as ``fold_splitter``.

    NOTE: pickle.load can execute arbitrary code; only load split files you generated yourself.
    """
    with open(path, 'rb') as file:
        return pickle.load(file)


train_data_0 = _load_split(train_file_0)
validation_data_0 = _load_split(validation_file_0)
train_data_1 = _load_split(train_file_1)
validation_data_1 = _load_split(validation_file_1)

dataset_train_0 = EmbryoDataset(root=root,
                                fold_splitter=train_data_0,
                                transforms=transforms_train)
dataset_validation_0 = EmbryoDataset(root=root,
                                     fold_splitter=validation_data_0,
                                     transforms=transforms_validation)

dataset_train_1 = EmbryoDataset(root=root,
                                fold_splitter=train_data_1,
                                transforms=transforms_train)
dataset_validation_1 = EmbryoDataset(root=root,
                                     fold_splitter=validation_data_1,
                                     transforms=transforms_validation)

print(len(dataset_validation_0))
print(len(dataset_validation_1))



In [5]:
# drop_last=True guarantees every training batch holds exactly batch_size samples.
dataloader_train_0 = DataLoader(dataset_train_0, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                drop_last=True)
# Validation order does not affect accuracy or F1, so keep it fixed for reproducible evaluation.
dataloader_valid_0 = DataLoader(dataset_validation_0, batch_size=1, shuffle=False, num_workers=num_workers)

In [6]:
# drop_last=True guarantees every training batch holds exactly batch_size samples.
dataloader_train_1 = DataLoader(dataset_train_1, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                drop_last=True)
# Validation order does not affect accuracy or F1, so keep it fixed for reproducible evaluation.
dataloader_valid_1 = DataLoader(dataset_validation_1, batch_size=1, shuffle=False, num_workers=num_workers)

In [7]:
def print_log(file, text="", timp=True):
    """Append one line to the text file ``file``.

    Args:
        file: path of the log file; opened in append mode, so it is created if missing.
        text: message to write.
        timp: when True, prefix the message with the current local datetime and ": ".
    """
    prefix = f"{datetime.fromtimestamp(time())}: " if timp else ""
    # `with` guarantees the handle is closed even if the write raises.
    with open(file, "a") as log_file:
        log_file.write(prefix + text + "\n")

In [8]:
def print_to_log(epoch, train_loss, valid_dice, valid_f1,n,n1,n2, start_time):
    """Append a per-epoch summary block to ``<checkpoint_dir>/train_log.txt``.

    Args:
        epoch: epoch number written on the header line (no timestamp).
        train_loss: logged as "train loss".
        valid_dice: logged as "valid metric" (the training loop passes validation accuracy).
        valid_f1: logged as "valid F1 score".
        n, n1, n2: logged as "Total correct", "Total correct F" and "Total correct T".
        start_time: ``time()`` value at the start of the epoch; used for the duration line.
    """
    log_path = checkpoint_dir + "/train_log.txt"
    print_log(log_path)  # timestamp-only separator line
    print_log(log_path, "epoch: {}".format(epoch), False)
    entries = (
        ("train loss", train_loss),
        ("valid metric", valid_dice),
        ("valid F1 score", valid_f1),
        ("Total correct", n),
        ("Total correct F", n1),
        ("Total correct T", n2),
    )
    for caption, value in entries:
        print_log(log_path, f"{caption}: {value}")

    # Measured after the lines above are written, so it includes the logging time.
    elapsed = time() - start_time
    print_log(log_path, f"This epoch took {elapsed:.4f} s")
    print_log(log_path, "", False)

In [9]:
def make_plot(arr, name, time):
    """Save a line plot of ``arr`` as ``<checkpoint_dir>/Classification <time> <name>.png``.

    Args:
        arr: sequence of values; the x axis is the list index.
        name: metric name, used for the y-axis label, title and filename.
        time: run label string (callers pass "train"/"valid"), used for the title,
            legend and filename. Inside this function it shadows the module-level
            ``time`` function.
    """
    caption = time + ' ' + name
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.set_title(caption)
    ax.plot(arr, label=time)
    ax.set_xlabel("iterations")
    ax.set_ylabel(name)
    ax.legend()
    fig.savefig(checkpoint_dir + "/Classification " + caption + '.png')
    # Close explicitly: this is called every validation epoch, so open figures would accumulate.
    plt.close(fig)

In [10]:
# Move the model to its device *before* constructing the optimizer, as the
# torch.optim documentation recommends.
model = model.to(device)

loss_function = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
# SAM wraps a base optimizer class (SGD here). The lr/momentum/weight_decay
# keywords are presumably forwarded to it — TODO confirm against utils.SAM.
base_optimizer = optim.SGD
optimizer = SAM(model.parameters(),
                base_optimizer, rho=rho, lr=learning_rate, momentum=momentum,
                weight_decay=weight_decay)
# Alternative previously tried:
#   optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# NOTE: overrides the `epochs` value from the config cell. Both the cosine
# schedule length (T_max) and the epoch progress bar below use this value.
epochs = 1000
# Assumes utils.SAM is a torch.optim.Optimizer exposing param_groups, as the
# LR scheduler requires — TODO confirm.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

# The training cell iterates this progress bar, so it fixes the number of epochs.
var = tqdm(range(epochs))


  0%|                                                                                                              | 0/1000 [00:00<?, ?it/s]

In [11]:
# Validate every `val_interval` epochs; val_interval_test is currently unused by the training loop.
val_interval, val_interval_test = 1, 5

# Best-so-far trackers (-1 = nothing recorded yet). The *_0 set tracks the
# combined validation of both folds; the *_1 set is currently unused.
best_metric_0 = best_metric_f1_0 = -1
best_metric_epoch_0 = best_metric_f1_epoch_0 = -1
best_metric_1 = best_metric_f1_1 = -1
best_metric_epoch_1 = best_metric_f1_epoch_1 = -1

# Per-epoch histories; each list is a separate object.
epoch_loss_values, metric_values_0, metric_values_f1_0 = [], [], []
metric_values_1, metric_values_f1_1 = [], []

# NOTE(review): the training loop iterates the progress bar `var`, which was
# built from the earlier `epochs = 1000`, so this reassignment does not change
# how many epochs are trained.
epochs = 750

In [12]:
# for x in dataloader_train:
#     print(x)

In [None]:
def _cycle(loader):
    """Yield batches from ``loader`` endlessly, starting a new pass whenever one ends.

    Each pass calls ``iter(loader)`` afresh, so a ``shuffle=True`` DataLoader is
    reshuffled per pass and walked through completely. Building a brand-new
    iterator on every training step would instead respawn the worker processes
    each time and draw batches with replacement.

    Raises:
        ValueError: on the first ``next()`` if ``loader`` has no batches, instead
            of spinning forever.
    """
    if len(loader) == 0:
        raise ValueError("DataLoader yields no batches; check dataset size, batch_size and drop_last")
    while True:
        yield from loader


def _evaluate(net, loader, target_device):
    """Run ``net`` on every batch of ``loader`` under ``torch.no_grad()``.

    Returns:
        tuple ``(num_correct, num_seen, predictions, targets)``:
        - num_correct: number of argmax predictions equal to the label (float);
        - num_seen: number of samples scored;
        - predictions, targets: per-sample predicted and true class indices, as lists.
    """
    num_correct = 0.0
    num_seen = 0
    predictions, targets = [], []
    with torch.no_grad():
        for images, batch_labels in loader:
            images, batch_labels = images.to(target_device), batch_labels.to(target_device)
            batch_preds = net(images).argmax(dim=1)
            hits = torch.eq(batch_preds, batch_labels)
            num_seen += len(hits)
            num_correct += hits.sum().item()
            # tolist() works for any validation batch size; .item() would only
            # accept single-sample batches.
            predictions.extend(batch_preds.cpu().tolist())
            targets.extend(batch_labels.cpu().tolist())
    return num_correct, num_seen, predictions, targets


# One persistent, self-restarting stream of fold-0 training batches; each
# fold-1 batch is paired with the next fold-0 batch.
train_batches_0 = _cycle(dataloader_train_0)
# Number of optimizer steps per epoch (one per fold-1 training batch).
epoch_len = len(dataloader_train_1)
if epoch_len == 0:
    raise ValueError("dataloader_train_1 yields no batches; nothing to train on")

for epoch in var:
    model.train()
    epoch_loss = 0
    step = 0
    start_time = time()
    for inputs_1, labels_1 in dataloader_train_1:
        inputs_0, labels_0 = next(train_batches_0)
        inputs = torch.cat((inputs_0, inputs_1))
        labels = torch.cat((labels_0, labels_1))
        # Shuffle the merged batch so samples from the two folds are interleaved.
        # The size comes from the actual batch rather than assuming 2 * batch_size.
        order = torch.randperm(inputs.size(0))
        inputs, labels = inputs[order], labels[order]
        step += 1

        inputs, labels = inputs.to(device), labels.to(device)
        # SAM update in two forward/backward passes. Assumes utils.SAM follows the
        # usual API: first_step moves to perturbed weights, and second_step restores
        # them and applies the base-optimizer update — TODO confirm.
        loss = loss_function(model(inputs), labels)
        loss.backward()
        optimizer.first_step(zero_grad=True)
        batch_loss = loss.item()
        epoch_loss += batch_loss
        var.set_description(f"{step}/{epoch_len}, train_loss: {batch_loss}")

        perturbed_loss = loss_function(model(inputs), labels)
        perturbed_loss.backward()
        optimizer.second_step(zero_grad=True)
    scheduler.step()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    var.set_description("epoch {} average loss: {}".format(epoch + 1, epoch_loss))

    if (epoch + 1) % val_interval == 0:
        model.eval()
        # Both validation folds are scored together as a single combined metric.
        num_correct_0, count_0, preds_0, targets_0 = _evaluate(model, dataloader_valid_0, device)
        num_correct_1, count_1, preds_1, targets_1 = _evaluate(model, dataloader_valid_1, device)
        num_correct = num_correct_0 + num_correct_1
        metric_count = count_0 + count_1
        f1_output = preds_0 + preds_1
        f1_labels = targets_0 + targets_1

        metric = num_correct / metric_count
        metric_values_0.append(metric)
        f1 = f1_score(f1_labels, f1_output, average='macro')
        metric_values_f1_0.append(f1)
        print(f"Images correct:  {int(num_correct)}: T:{int(num_correct_1)}, F:{int(num_correct_0)} ")
        if metric > best_metric_0:
            best_metric_0 = metric
            best_metric_epoch_0 = epoch + 1
            torch.save(model.state_dict(), checkpoint_dir+"/model_best_0.pth")

        if f1 > best_metric_f1_0:
            best_metric_f1_0 = f1
            best_metric_f1_epoch_0 = epoch + 1
            torch.save(model.state_dict(), checkpoint_dir+"/model_best_f1_0.pth")

        print_to_log(epoch, epoch_loss_values[-1], metric_values_0[-1], metric_values_f1_0[-1],
                     num_correct, num_correct_0, num_correct_1, start_time)

        make_plot(epoch_loss_values, "Lossss", "train")
        make_plot(metric_values_0, "Metric_0", "valid")
        make_plot(metric_values_f1_0, "F1 score_0", "valid")

    if (epoch + 1) % 25 == 0:
        torch.save(model.state_dict(), checkpoint_dir+"/last_metric_model.pth")

# Release the fold-0 stream and the DataLoader iterator it holds.
train_batches_0.close()

# Report the combined-validation bests, which are tracked in the *_0 variables.
print_log(checkpoint_dir+"/train_log.txt",
          f"Training completed, best_metric: {best_metric_0} at epoch: {best_metric_epoch_0}, "
          f"best F1: {best_metric_f1_0} at epoch: {best_metric_f1_epoch_0}", False)
