# In [None]:
from torch.utils.data import Dataset
import torch
import pandas as pd
import albumentations
import os
import numpy as np
from PIL import Image
def add_path(data_dir, name):
    """Return the path of image ``name`` inside ``data_dir``, trying known extensions.

    Extensions are probed in a fixed priority order: ``.jpg``, ``.jpeg``, ``.JPG``,
    ``.png``. The first candidate that exists on disk is returned.

    Args:
        data_dir: folder to look in.
        name: file name without extension.

    Returns:
        The joined path of the first existing candidate, or ``None`` if none exist.
    """
    for extension in ('.jpg', '.jpeg', '.JPG', '.png'):
        candidate = os.path.join(data_dir, name + extension)
        if os.path.exists(candidate):
            return candidate
    return None


def load_excel(data_dir, list_file, n_class):
    """Build a class-balanced list of image paths and labels from a CSV file.

    Despite the name, ``list_file`` is read with ``pandas.read_csv``. It must have
    an ``image`` column (file name without extension) and a ``class`` column
    (integer label). Only labels ``0 .. n_class - 1`` are used.

    The data are undersampled to the size of the smallest class. The first
    ``min(class counts)`` rows of each class are kept and emitted interleaved class
    by class (class 0, class 1, ..., class n_class-1, class 0, ...).

    Args:
        data_dir: a folder, or a sequence of folders searched in order; the first
            folder where ``add_path`` finds the image wins.
        list_file: path to the CSV file.
        n_class: number of classes.

    Returns:
        ``(image_paths, labels)`` lists of equal length. Images that cannot be
        found are reported with a print and skipped, so the result can end up
        slightly unbalanced.
    """
    # A bare path string would otherwise be iterated character by character.
    if isinstance(data_dir, (str, os.PathLike)):
        data_dir = [data_dir]

    df_tmp = pd.read_csv(list_file)
    # Row labels (DataFrame index values) of every sample of each class.
    indices_per_class = [
        df_tmp.index[df_tmp["class"] == c].tolist() for c in range(n_class)
    ]
    minority_class = min((len(rows) for rows in indices_per_class), default=0)

    image_paths = []
    labels = []
    for ix in range(minority_class):
        for rows in indices_per_class:
            row = rows[ix]
            # str(): numeric names would break the `name + ext` concat in add_path.
            image_name = str(df_tmp["image"][row])
            label = df_tmp["class"][row]

            # Start from None (not ""), so an empty folder list counts as "not found"
            # instead of appending an empty path.
            p = None
            for folder in data_dir:
                p = add_path(folder, image_name)
                if p is not None:
                    break

            if p is None:
                print(f"Image not found for {image_name}")
            else:
                image_paths.append(p)
                labels.append(label)
    return image_paths, labels


class DatasetGenerator(Dataset):
    """Map-style dataset that returns ``(image, label)`` pairs.

    Samples come from ``load_excel(data_dir, list_file, n_class)``. Without a
    transform, the raw PIL image is returned. With one, the transform's output is
    passed through ``torch.FloatTensor``.
    """

    def __init__(self, data_dir, list_file, transform=None, n_class=6):
        paths, targets = load_excel(data_dir, list_file, n_class)

        self.image_names = paths
        self.classes = list(set(targets))
        self.labels = targets
        self.n_class = n_class
        self.transform = transform

    def __getitem__(self, index):
        path = self.image_names[index]
        target = self.labels[index]
        picture = Image.open(path)

        if self.transform is None:
            return picture, target
        return torch.FloatTensor(self.transform(picture)), target

    def get_path(self, index):
        """Return the file path of sample ``index``."""
        return self.image_names[index]

    def __len__(self):
        return len(self.image_names)


class DatasetGenerator2(Dataset):
    """Dataset that yields each image as four CLAHE-enhanced quadrant tensors.

    Samples come from ``load_excel(data_dir, list_file, n_class)``. Each item is
    ``(part1, part2, part3, part4, label)``, where the parts are the top-left,
    top-right, bottom-left and bottom-right quadrants of the transformed image.

    ``transform`` is required at indexing time. It must turn a PIL image into a
    ``(C, H, W)`` tensor (e.g. ``transforms.ToTensor()``) so the result can be
    split into quadrants.
    """

    def __init__(self, data_dir, list_file, transform=None, n_class=6):
        image_names, labels = load_excel(data_dir, list_file, n_class)

        self.image_names = image_names
        self.classes = list(set(labels))
        self.labels = labels
        self.n_class = n_class
        self.transform = transform
        # NOTE(review): clip_limit=(1, 4) is presumably a range albumentations
        # samples from on each call — confirm against the albumentations docs.
        self.CLAHE = albumentations.Compose([albumentations.CLAHE(clip_limit=(1, 4), p=1)])

    def __getitem__(self, index):
        if self.transform is None:
            # Without a transform the PIL image has no `.shape` to split on.
            raise ValueError(
                "DatasetGenerator2 needs a transform that returns a (C, H, W) "
                "tensor, e.g. transforms.ToTensor()"
            )

        image_name = self.image_names[index]
        label = self.labels[index]

        # Force 3 channels: palette/RGBA/grayscale files (e.g. .png) would otherwise
        # reach CLAHE and the 3-channel ResNet-50 stem with the wrong channel count.
        # The context manager closes the file once the pixels are copied out.
        with Image.open(image_name) as image:
            np_image = np.array(image.convert("RGB"))

        # Apply CLAHE contrast enhancement, then the user transform.
        enhanced = Image.fromarray(self.CLAHE(image=np_image)['image'])
        tensor = self.transform(enhanced)

        part1, part2, part3, part4 = self._split_quadrants(tensor)
        return part1, part2, part3, part4, label

    @staticmethod
    def _split_quadrants(image):
        """Split a ``(C, H, W)`` tensor into its TL, TR, BL, BR quadrants.

        For odd H or W, the bottom/right quadrants get the extra row/column.
        """
        h = image.shape[1]
        w = image.shape[2]
        top, left = h // 2, w // 2
        return (
            image[:, :top, :left],
            image[:, :top, left:],
            image[:, top:, :left],
            image[:, top:, left:],
        )

    def get_path(self, index):
        """Return the file path of sample ``index``."""
        return self.image_names[index]

    def __len__(self):
        return len(self.image_names)

# In [None]:
import torchvision.models as models
from torchvision.models import Inception3, Inception_V3_Weights, resnet50, ResNet50_Weights
import torch.nn as nn
from torchvision.utils import save_image
import torch

def conv3x3(in_: int, out: int) -> nn.Module:
    """Return a 3x3 ``nn.Conv2d`` mapping ``in_`` to ``out`` channels.

    Stride 1 with padding 1 keeps the spatial size unchanged.
    """
    return nn.Conv2d(in_channels=in_, out_channels=out, kernel_size=3, padding=1)

class ConvRelu(nn.Module):
    """Size-preserving 3x3 convolution (``conv3x3``) followed by an in-place ReLU."""

    def __init__(self, in_: int, out: int) -> None:
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activation(self.conv(x))

class Interpolate(nn.Module):
    """``nn.Module`` wrapper around ``torch.nn.functional.interpolate``.

    Args:
        size: output spatial size (mutually exclusive with ``scale_factor``).
        scale_factor: spatial multiplier.
        mode: interpolation mode passed to ``F.interpolate``.
        align_corners: forwarded only for modes that accept it (see
            ``_ALIGNABLE_MODES``). For other modes, such as the default
            ``"nearest"``, ``None`` is passed, because ``F.interpolate`` raises
            when ``align_corners`` is set for them.
    """

    # Modes for which F.interpolate accepts a non-None align_corners.
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(
        self,
        size: int = None,
        scale_factor: int = None,
        mode: str = "nearest",
        align_corners: bool = False,
    ):
        super().__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode
        self.scale_factor = scale_factor
        self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        align_corners = self.align_corners if self.mode in self._ALIGNABLE_MODES else None
        return self.interp(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=align_corners,
        )

class DecoderBlock(nn.Module):
    """Upsampling stage that doubles H and W and maps channels to ``out_channels``.

    Args:
        in_channels: channels of the incoming feature map.
        middle_channels: channels after the first ConvRelu.
        out_channels: channels of the output feature map.
        is_deconv: if True, upsample with a transposed convolution. Otherwise
            use bilinear interpolation followed by two ConvRelu layers.
    """

    def __init__(
        self,
        in_channels: int,
        middle_channels: int,
        out_channels: int,
        is_deconv: bool = True,
    ):
        super().__init__()
        self.in_channels = in_channels

        if is_deconv:
            # Parameters for the transposed convolution (kernel 4, stride 2,
            # padding 1) were chosen to avoid checkerboard artifacts, following
            # https://distill.pub/2016/deconv-checkerboard/
            stages = [
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(
                    middle_channels, out_channels, kernel_size=4, stride=2, padding=1
                ),
                nn.ReLU(inplace=True),
            ]
        else:
            stages = [
                Interpolate(scale_factor=2, mode="bilinear"),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class ResNet50_Base(nn.Module):
    """Pretrained ResNet-50 trunk (stem + layer1..layer4) used as a feature encoder.

    ``forward`` returns the unpooled layer4 feature map. No pooling or
    classification head is applied here (``self.avgpool`` is stored but not used
    by ``forward``).

    NOTE(review): the attribute names assigned in ``__init__`` become this
    module's state_dict keys, so renaming them breaks loading of existing
    checkpoints.
    """
    def __init__(self, n_classes,name,is_deconv: bool = False,):
        # n_classes and name are only stored; is_deconv is accepted but unused.
        super(ResNet50_Base,self).__init__()
        self.n_classes = n_classes
        self.name = name
        resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) # use pretrained ImageNet weights

        # Stem layers. They are registered both as attributes here and inside
        # encode1 below, so the same parameters appear under two state_dict prefixes.
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        # Stem without the max-pool; forward applies maxpool separately after it.
        self.encode1 = nn.Sequential(
            self.conv1,
            self.bn1,
            self.relu,
        )
        # encoder: the four residual stages of ResNet-50
        self.encode2  = resnet.layer1
        self.encode3 = resnet.layer2
        self.encode4 = resnet.layer3
        self.encode5 = resnet.layer4

        # Kept from the pretrained model but not used in forward.
        self.avgpool = resnet.avgpool

        # placeholder for the gradients captured by activations_hook
        self.gradients = None

    def forward(self, input):
        """Encode ``input`` and return the layer4 feature map.

        NOTE(review): for torchvision's resnet50 this is presumably
        (B, 2048, H/32, W/32) — confirm for the input sizes used here.
        """
        e1 = self.encode1(input)
        e2 = self.encode2(self.maxpool(e1))
        e3 = self.encode3(e2)
        e4 = self.encode4(e3)
        e5 = self.encode5(e4)
        # The gradient-capture hook is disabled: with the line below commented out,
        # forward never sets self.gradients, so get_activations_gradient() returns
        # None unless activations_hook is called from elsewhere.
        #h = e5.register_hook(self.activations_hook)
        output = e5

        return output

    # hook for the gradients of the activations: stores the incoming gradient
    def activations_hook(self, grad):
        self.gradients = grad

    # method for the gradient extraction: returns whatever activations_hook last stored (initially None)
    def get_activations_gradient(self):
        return self.gradients




class AutoRes50(nn.Module):
    """Quadrant classifier with one ResNet50_Base encoder per image quadrant.

    ``forward(part1, part2, part3, part4)`` takes the top-left, top-right,
    bottom-left and bottom-right quadrants, the order DatasetGenerator2 yields
    them in. The four encoder feature maps are stitched back into a 2x2 mosaic,
    average-pooled, passed through the pretrained ImageNet ``fc`` head
    (1000 outputs) and then through a ``Linear(1000, n_classes)`` classifier.
    """

    def __init__(self, n_classes, name):
        super(AutoRes50, self).__init__()
        self.n_classes = n_classes
        self.name = name
        # Only used as the source of the pretrained avgpool and fc layers below.
        resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

        self.base1 = ResNet50_Base(n_classes, "base1")  # top-left
        self.base2 = ResNet50_Base(n_classes, "base2")  # top-right
        self.base3 = ResNet50_Base(n_classes, "base3")  # bottom-left
        self.base4 = ResNet50_Base(n_classes, "base4")  # bottom-right

        self.avg = resnet.avgpool
        self.fc = resnet.fc

        self.classifier = nn.Sequential(
            nn.Linear(1000, n_classes)
        )

    @staticmethod
    def _stitch(top_left, top_right, bottom_left, bottom_right):
        """Reassemble four (B, C, h, w) maps into one (B, C, 2h, 2w) mosaic."""
        top = torch.cat([top_left, top_right], dim=3)
        bottom = torch.cat([bottom_left, bottom_right], dim=3)
        return torch.cat([top, bottom], dim=2)

    def forward(self, part1, part2, part3, part4):
        combine = self.get_activations(part1, part2, part3, part4)
        x = self.avg(combine)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return self.classifier(x)

    def get_activations_gradient(self):
        """Stitch the gradients captured by each quadrant encoder into one mosaic.

        Raises:
            RuntimeError: if any encoder has not captured a gradient (the hook in
                ResNet50_Base.forward must be active and backward must have run).
        """
        # Each quadrant reads its *own* encoder.
        grads = [
            base.get_activations_gradient()
            for base in (self.base1, self.base2, self.base3, self.base4)
        ]
        if any(g is None for g in grads):
            raise RuntimeError(
                "no activation gradients captured; register "
                "ResNet50_Base.activations_hook and run backward first"
            )
        return self._stitch(*grads)

    # method for the activation extraction
    def get_activations(self, part1, part2, part3, part4):
        """Encode each quadrant and return the stitched 2x2 feature mosaic."""
        return self._stitch(
            self.base1(part1),
            self.base2(part2),
            self.base3(part3),
            self.base4(part4),
        )





# In [None]:
import torchvision
import numpy as np
import pandas as pd
from tqdm import tqdm
from glob import glob
import torch.nn as nn
import torch.autograd
import pathlib
import torch, torchvision
from matplotlib import rc
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.utils import save_image
#from DataGenerator import DatasetGenerator2
#from model import ResNet50_DR,VGG19_DR,ResNet50_DR_V2
#from modelauto import AutoRes50
import matplotlib.pyplot as plt
#from modelrexnet import ReXNetV2
# ================================================================ #
def train_epoch(model, dataloaders, loss_fn, loss_MSE, optimizer, device, scheduler, n_examples):
    """Run one training epoch over quadrant batches, then step the scheduler once.

    Args:
        model: network called as ``model(part1, part2, part3, part4)`` -> logits.
        dataloaders: iterable of ``(part1, part2, part3, part4, label)`` batches.
        loss_fn: criterion applied to ``(logits, labels)``.
        loss_MSE: unused; kept for call compatibility.
        optimizer: optimizer updated after every batch.
        device: device the batches are moved to.
        scheduler: LR scheduler, stepped once at the end of the epoch.
        n_examples: number of samples, used as the accuracy denominator.

    Returns:
        ``(model, accuracy, mean_loss)``. ``accuracy`` is a 0-d float64 tensor
        (0 when ``n_examples`` is 0). ``mean_loss`` is the mean batch loss, or NaN
        when the loader is empty.
    """
    model = model.train()
    losses = []
    correct_predictions = 0
    for part1, part2, part3, part4, label in tqdm(dataloaders):
        part1 = part1.to(device)
        part2 = part2.to(device)
        part3 = part3.to(device)
        part4 = part4.to(device)
        labels = label.to(device)

        result = model(part1, part2, part3, part4)
        _, preds = torch.max(result, dim=1)
        loss = loss_fn(result, labels)

        correct_predictions += int(torch.sum(preds == labels).item())
        losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()

    accuracy = torch.tensor(
        correct_predictions / n_examples if n_examples else 0.0, dtype=torch.float64
    )
    mean_loss = np.mean(losses) if losses else float("nan")
    return model, accuracy, mean_loss

# ================================================================ #
def eval_model(model, dataloaders, loss_fn, loss_MSE, device, n_examples):
    """Evaluate ``model`` on quadrant batches without tracking gradients.

    Args:
        model: network called as ``model(part1, part2, part3, part4)`` -> logits.
        dataloaders: iterable of ``(part1, part2, part3, part4, label)`` batches.
        loss_fn: criterion applied to ``(logits, labels)``.
        loss_MSE: unused; kept for call compatibility.
        device: device the batches are moved to.
        n_examples: number of samples, used as the accuracy denominator.

    Returns:
        ``(accuracy, mean_loss)``. ``accuracy`` is a 0-d float64 tensor (0 when
        ``n_examples`` is 0). ``mean_loss`` is the mean batch loss, or NaN when
        the loader is empty.
    """
    model = model.eval()
    losses = []
    correct_predictions = 0
    with torch.no_grad():
        for part1, part2, part3, part4, label in tqdm(dataloaders):
            part1 = part1.to(device)
            part2 = part2.to(device)
            part3 = part3.to(device)
            part4 = part4.to(device)
            labels = label.to(device)

            result = model(part1, part2, part3, part4)
            _, preds = torch.max(result, dim=1)
            loss = loss_fn(result, labels)

            correct_predictions += int(torch.sum(preds == labels).item())
            losses.append(loss.item())

    accuracy = torch.tensor(
        correct_predictions / n_examples if n_examples else 0.0, dtype=torch.float64
    )
    mean_loss = np.mean(losses) if losses else float("nan")
    return accuracy, mean_loss

# ================================================================ #
def checkpoint_path(filename, model_name, root='/checkpoint-Resnet-6class'):
    """Return ``root/model_name/filename``, creating the folder if needed.

    Args:
        filename: checkpoint file name.
        model_name: sub-folder name under ``root``.
        root: base checkpoint directory. Defaults to the original hard-coded
            location, so existing callers are unaffected.

    Returns:
        A ``pathlib.Path`` to the checkpoint file. The file itself is not created.
    """
    checkpoint_folderpath = pathlib.Path(root) / model_name
    print(checkpoint_folderpath)
    checkpoint_folderpath.mkdir(exist_ok=True, parents=True)
    return checkpoint_folderpath / filename
# ================================================================ #

def train_model(model, dataloaders_train, dataloaders_val, dataset_sizes_train, dataset_sizes_val, device, n_epochs=50):
    """Train ``model`` with SGD + StepLR, checkpoint each epoch, and reload the best.

    Every epoch's weights are saved as ``best_model_state_<epoch>.ckpt``. The
    weights with the highest validation accuracy are saved as
    ``best_model_state.ckpt`` and loaded back into ``model`` at the end.
    Accuracy and loss curves are saved via ``plot_metrics``.

    Args:
        model: quadrant model called as ``model(part1, part2, part3, part4)``.
        dataloaders_train, dataloaders_val: quadrant-batch loaders.
        dataset_sizes_train, dataset_sizes_val: sample counts used for accuracy.
        device: device the batches and loaded checkpoint are mapped to.
        n_epochs: number of epochs to run, starting at epoch 1.

    Returns:
        ``model`` with the best-validation-accuracy weights loaded.
    """
    # NOTE: an earlier comment claimed this continues training from epoch 18;
    # no prior checkpoint is loaded here, so training always starts at epoch 1.
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    # LR x0.1 every 7 scheduler steps; train_epoch steps the scheduler once per epoch.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    loss_fn = nn.CrossEntropyLoss(reduction='mean').to(device)
    loss_MSE = nn.MSELoss().to(device)  # passed along, but unused by train_epoch/eval_model
    best_model_path = checkpoint_path('best_model_state.ckpt', "ResNet50_DR_Auto")

    train_accuracy = []
    train_losses = []
    val_accuracy = []
    val_losses = []

    # -inf so the first epoch always writes best_model_path. Starting at 0 with a
    # strict ">" meant a run whose validation accuracy never rose above 0 never
    # saved it, and the torch.load below crashed.
    best_accuracy = float('-inf')
    for epoch in range(1, n_epochs + 1):
        print(f'Epoch {epoch}/{n_epochs}')
        print('-' * 10)
        model, train_acc, train_loss = train_epoch(model, dataloaders_train, loss_fn, loss_MSE, optimizer, device, scheduler, n_examples=dataset_sizes_train)
        print(f'Train loss {train_loss} accuracy {train_acc}')
        val_acc, val_loss = eval_model(model, dataloaders_val, loss_fn, loss_MSE, device, n_examples=dataset_sizes_val)
        print(f'validation   loss {val_loss} accuracy {val_acc}')
        train_accuracy.append(float(train_acc))
        train_losses.append(train_loss)
        val_accuracy.append(float(val_acc))
        val_losses.append(val_loss)

        torch.save(model.state_dict(), checkpoint_path(f'best_model_state_{epoch}.ckpt', "ResNet50_DR_Auto"))
        if float(val_acc) > best_accuracy:
            torch.save(model.state_dict(), best_model_path)
            best_accuracy = float(val_acc)

    # map_location lets a checkpoint written on another device load onto `device`.
    model.load_state_dict(torch.load(best_model_path, map_location=device))

    plot_metrics(train_accuracy, val_accuracy, 'Accuracy')
    plot_metrics(train_losses, val_losses, 'Loss')
    return model
# ================================================================ #

def plot_metrics(train_metrics, val_metrics, metric_name):
    """Plot training vs. validation curves for one metric and save the figure.

    The image is written with ``plt.savefig(metric_name)``. When ``metric_name``
    has no extension, matplotlib's default save format is used.

    A fresh figure is created on every call. Drawing on pyplot's current figure
    would stack the curves of successive calls (e.g. accuracy, then loss) into
    the same saved image.

    Args:
        train_metrics: per-epoch training values.
        val_metrics: per-epoch validation values (same length).
        metric_name: used for the title, axis label, legend and output file name.
    """
    epochs = range(1, len(train_metrics) + 1)
    plt.figure()
    plt.plot(epochs, train_metrics, 'bo-', label=f'Training {metric_name}')
    plt.plot(epochs, val_metrics, 'ro-', label=f'Validation {metric_name}')
    plt.xticks(epochs)
    plt.title(f'Training and Validation {metric_name}')
    plt.xlabel('Epochs')
    plt.ylabel(metric_name)
    plt.legend()
    plt.savefig(metric_name)



if __name__ == '__main__':
    # Folders searched, in order, for each image listed in the CSVs (6-class setup).
    # Named data_dirs rather than `dir`, which would shadow the builtin for the
    # rest of the notebook session.
    data_dirs = [
        '/content/drive/MyDrive/combinedata/content/combine',
        '/content/drive/MyDrive/TRAIN/train',
        '/content/drive/MyDrive/TEST',
        '/content/drive/MyDrive/augment_fold/augment_fold1',
        '/content/drive/MyDrive/augment_fold/augment_fold2',
    ]

    label_train_file = '/content/drive/MyDrive/train_folder/fold1.csv'
    label_val_file = '/content/drive/MyDrive/train_folder/f2.csv'

    # ================================================================ #
    # Preprocessing (no random augmentation): resize to 448x448, so each
    # quadrant produced by DatasetGenerator2 is 224x224.
    train_transforms = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
    ])
    val_transform = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
    ])
    # ================================================================ #

    Imagedataset_train = DatasetGenerator2(data_dir=data_dirs, list_file=label_train_file,
                                           n_class=6, transform=train_transforms)

    Imagedataset_val = DatasetGenerator2(data_dir=data_dirs, list_file=label_val_file,
                                         n_class=6, transform=val_transform)

    # Shuffle training batches: load_excel emits samples in a fixed
    # class-interleaved order, so unshuffled batches always have the same class mix.
    dataloaders_train = torch.utils.data.DataLoader(Imagedataset_train, batch_size=4, shuffle=True, num_workers=2)
    dataloaders_val = torch.utils.data.DataLoader(Imagedataset_val, batch_size=4, shuffle=False, num_workers=2)

    dataset_train_sizes = len(Imagedataset_train)
    dataset_val_sizes = len(Imagedataset_val)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # ================================================================ #
    print(torch.cuda.is_available())
    print("Device: ", device)
    print(f"train size: {len(Imagedataset_train)}")
    print(f"val size: {len(Imagedataset_val)}")

    model_name = 'ResNet50_DR_Auto'
    # ================================================================ #
    model = AutoRes50(n_classes=6, name=model_name)
    model.to(device)
    # ================================================================ #
    model = train_model(model, dataloaders_train, dataloaders_val, dataset_train_sizes, dataset_val_sizes, device, n_epochs=20)
    # ================================================================ #
  # ================================================================ #

