In [None]:
import torchvision
from torchvision import transforms, models
import torch
import torch.nn.functional as F
import torch.nn as nn   

from sklearn.model_selection import KFold

from rembg.bg import remove
from focal_loss.focal_loss import FocalLoss

import pandas as pd

import os
import shutil

from tqdm import tqdm
import numpy as np

import matplotlib.pyplot as plt
import time
import copy


## Data preparation

In [None]:
# Per-channel mean/std handed to Normalize in both pipelines below.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: 224-pixel centre crop, random rotation between 0 and 180
# degrees, tensor conversion, normalisation.
# Augmentations that were tried and are currently disabled:
# RandomHorizontalFlip, RandomVerticalFlip, RandomPerspective, RandomInvert.
train_transforms = transforms.Compose(
    [
        transforms.CenterCrop(224),
        transforms.RandomRotation(degrees=(0, 180)),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

# Validation pipeline: no random augmentation, images are resized to 224x224
# instead of cropped.
val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

# NOTE(review): train_dir and val_dir are not defined in the visible cells;
# they must exist before this cell runs.
train_dataset = torchvision.datasets.ImageFolder(train_dir, transform=train_transforms)
val_dataset = torchvision.datasets.ImageFolder(val_dir, transform=val_transforms)

batch_size = 8

# Only the training loader shuffles; validation keeps the dataset order.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Second training set, read from `data_train` with the same training transforms.
# NOTE(review): data_train is not defined in the visible cells; confirm it is
# set before this cell runs.
train_another_dataset = torchvision.datasets.ImageFolder(
    data_train,
    transform=train_transforms,
)

batch_size = 8

train_another_dataloader = torch.utils.data.DataLoader(
    train_another_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)


In [None]:
# Sanity check: number of batches in the training loader and length of
# train_end_dataset.
# NOTE(review): train_end_dataset is not defined in the visible cells; confirm
# it is created before this cell runs.
len(train_dataloader), len(train_end_dataset)

In [None]:
# Pull a single batch and preview its first image with the normalisation undone.
X_batch, y_batch = next(iter(train_dataloader))

# Same values that were passed to transforms.Normalize; show_input reads these
# two module-level arrays as well.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# permute(1, 2, 0) moves the first axis last, which is the layout imshow expects
# (assumes the tensor is channels-first).
first_image = X_batch[0].permute(1, 2, 0).numpy()
plt.imshow(first_image * std + mean);

In [None]:
def show_input(input_tensor, title=''):
    """Draw one normalised image tensor with matplotlib.

    The tensor's first axis is moved last, the normalisation is reverted with
    the module-level ``std`` and ``mean`` arrays, and the result is clipped to
    the [0, 1] range before being displayed.

    Args:
        input_tensor: image tensor (assumed channels-first, on the CPU).
        title: text placed above the image.
    """
    channels_last = input_tensor.permute(1, 2, 0).numpy()
    restored = std * channels_last + mean
    restored = restored.clip(0, 1)

    plt.imshow(restored)
    plt.title(title)
    plt.show()
    plt.pause(0.001)

# Draw a new training batch and display every image with its class name.
# NOTE(review): class_names is not defined in the visible cells; confirm it is
# available (e.g. taken from the dataset) before this cell runs.
X_batch, y_batch = next(iter(train_dataloader))

for sample_idx in range(len(X_batch)):
    x_item, y_item = X_batch[sample_idx], y_batch[sample_idx]
    show_input(x_item, title=class_names[y_item])

## Train function

In [None]:
def train_model(model, loss, optimizer, scheduler, num_epochs, data_tr, data_val, device):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Args:
        model: network to optimise; its parameters are updated in place.
        loss: callable invoked as ``loss(preds, labels)`` that returns a
            scalar tensor.
        optimizer: optimizer stepped once per training batch.
        scheduler: learning-rate scheduler stepped once per epoch, after the
            training phase of that epoch.
        num_epochs: number of train + validation passes to run.
        data_tr: iterable of ``(inputs, labels)`` batches used for training.
        data_val: iterable of ``(inputs, labels)`` batches used for validation.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``[model, val_loss, val_acc]``. The two metrics are the batch-averaged
        validation loss and accuracy (plain floats) of the last epoch, or ``0``
        when ``num_epochs`` is 0.
    """
    # Bound up front so the return statement is valid even if no epoch runs.
    epl = 0
    epcc = 0

    for epoch in range(num_epochs):
        print('Epoch {}/{}:'.format(epoch, num_epochs - 1), flush=True)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_training = phase == 'train'
            if is_training:
                # Use the loaders passed by the caller, not notebook globals,
                # so every cross-validation fold trains on its own split.
                dataloader = data_tr
                model.train()  # Set model to training mode
            else:
                dataloader = data_val
                model.eval()   # Set model to evaluate mode

            running_loss = 0.
            running_acc = 0.
            num_batches = 0

            # Iterate over data.
            for inputs, labels in tqdm(dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Gradients are only tracked during the training phase.
                with torch.set_grad_enabled(is_training):
                    preds = model(inputs)
                    loss_value = loss(preds, labels)
                    preds_class = preds.argmax(dim=1)

                    # backward + optimize only if in training phase
                    if is_training:
                        loss_value.backward()
                        optimizer.step()

                # statistics, kept as Python floats so nothing accumulates on
                # the device between batches
                running_loss += loss_value.item()
                running_acc += (preds_class == labels).float().mean().item()
                num_batches += 1

            # The scheduler must advance after the optimizer has stepped;
            # stepping it first skips the initial learning-rate value.
            if is_training:
                scheduler.step()

            # Guard against an empty loader instead of dividing by zero.
            denominator = max(num_batches, 1)
            epoch_loss = running_loss / denominator
            epoch_acc = running_acc / denominator

            if not is_training:
                epl = epoch_loss
                epcc = epoch_acc

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc), flush=True)

    return [model, epl, epcc]

## Mod-resnet

In [None]:
def modified_resnet():
    """Build a pretrained ResNet-152 with a new two-output head.

    All pretrained parameters are frozen first. The pooling layer and the
    final fully connected layer are then replaced, so only the new head has
    trainable parameters.

    Returns:
        The modified model.
    """
    net = models.resnet152(pretrained=True)

    # Freeze everything that came with the pretrained network. The layers
    # created below are added afterwards and therefore stay trainable.
    net.requires_grad_(False)

    # Global average pooling, wrapped in a Sequential so extra layers can be
    # slotted in front of it.
    net.avgpool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)))

    # Read the feature width before the original fc layer is replaced.
    head_inputs = net.fc.in_features
    net.fc = nn.Sequential(
        nn.Dropout(),
        nn.Linear(head_inputs, 2),
    )
    return net

## K-fold validation

In [None]:
# Five-way splitter used by k_fold().
# NOTE(review): shuffling is left at its default (off), so folds are contiguous
# index ranges. If train_end_dataset is ordered by class, confirm that is
# intended.
kfold = KFold(n_splits=5)

In [None]:
def k_fold():
    """Cross-validate ``modified_resnet`` over the splits of the global ``kfold``.

    For every split a fresh model, FocalLoss, Adam optimizer (lr 1e-3) and
    StepLR scheduler (step 7, gamma 0.1) are created. The selected samples of
    the global ``train_end_dataset`` are materialised into lists, wrapped in
    data loaders with batch size 10, and ``train_model`` is run for 20 epochs.

    Returns:
        ``[mean_loss, mean_acc]``: ``np.mean`` of the per-fold loss and
        accuracy values returned by ``train_model``.
    """
    fold_losses = []
    fold_accuracies = []

    for train_index, test_index in kfold.split(train_end_dataset):
        # A new network per fold, so no weights carry over between folds.
        fold_model = modified_resnet()

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        fold_model = fold_model.to(device)

        # FocalLoss is used here; torch.nn.CrossEntropyLoss was the earlier choice.
        # NOTE(review): the model returns raw scores; confirm this FocalLoss
        # implementation accepts them and can be built without arguments.
        criterion = FocalLoss()
        optimizer = torch.optim.Adam(fold_model.parameters(), lr=1.0e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

        # Indexing the dataset up front stores each selected sample once, as
        # returned by the dataset at this point.
        train_tensor = [train_end_dataset[idx] for idx in train_index]
        val_tensor = [train_end_dataset[idx] for idx in test_index]

        loader_options = dict(batch_size=10, num_workers=2)
        train_dataloader = torch.utils.data.DataLoader(
            train_tensor, shuffle=True, **loader_options)
        val_dataloader = torch.utils.data.DataLoader(
            val_tensor, shuffle=False, **loader_options)

        fold_result = train_model(
            fold_model,
            criterion,
            optimizer,
            scheduler,
            num_epochs=20,
            data_tr=train_dataloader,
            data_val=val_dataloader,
            device=device,
        )

        # train_model returns [model, loss, accuracy]; the model is discarded.
        fold_losses.append(fold_result[1])
        fold_accuracies.append(fold_result[2])

    return [np.mean(fold_losses), np.mean(fold_accuracies)]

## Work

In [None]:
# Run the full cross-validation; result_metrics is [mean loss, mean accuracy]
# across the folds, as returned by k_fold().
result_metrics = k_fold()