In [1]:
import matplotlib.pyplot as plt
import numpy as np
import helper

import torch.nn as nn
import torchvision.models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import torchvision.utils
import torch
import pandas as pd
from torchinfo import summary
from PIL import Image
from torchvision.transforms import ToTensor
from glob import glob
from torch.utils.data import Dataset, DataLoader, random_split
from copy import copy
from collections import defaultdict
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import time
from sklearn.metrics import classification_report
from tqdm.notebook import tqdm
import math
from torcheval.metrics import BinaryAccuracy
import os
import torchmetrics
import timm
# Run configuration.
# Prefer the second GPU (the original target), fall back to the first GPU when only
# one is visible (hard-coding "cuda:1" fails there), and to CPU without CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
batch_size = 2      # dataset items (tile directories) per batch
image_count = 20    # tiles sampled per dataset item in CustomDataset.__getitem__
img_size = 512      # every tile is resized to img_size x img_size
# Tile sampling (torch.randint) and DataLoader shuffling are random; seed them so
# Restart & Run All reproduces the same run.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
tf = ToTensor()


In [2]:
# Dataset layout: <DATA_ROOT>/<split>/<class>/<MAGNIFICATION>/<item_dir>/*.jpg.
# Each glob below matches one directory per dataset item; change DATA_ROOT or
# MAGNIFICATION here instead of editing four literals.
DATA_ROOT = '../../data/ensemble_tile/CN'
MAGNIFICATION = '5x'
train_image_transition_path = f'{DATA_ROOT}/train/transition/{MAGNIFICATION}/*'
train_image_not_transition_path = f'{DATA_ROOT}/train/not_transition/{MAGNIFICATION}/*'
test_image_transition_path = f'{DATA_ROOT}/test/transition/{MAGNIFICATION}/*'
test_image_not_transition_path = f'{DATA_ROOT}/test/not_transition/{MAGNIFICATION}/*'
class CustomDataset(Dataset):
    """Multiple-instance dataset: each item is a bag of tiles read from one directory.

    Args:
        image_list: directory paths, one per item; each directory holds ``*.jpg`` tiles.
        label_list: per-item labels indexable by position (the caller passes a
            one-hot tensor). ``len(label_list)`` defines the dataset length.

    ``__getitem__`` returns ``(tiles, label)``, where ``tiles`` is a float tensor of
    shape ``(image_count, 3, img_size, img_size)``. Tiles are drawn uniformly *with
    replacement* on every access, so the same index yields a different bag each time.

    Raises:
        FileNotFoundError: from ``__getitem__`` when an item directory has no ``.jpg`` tiles.
    """

    def __init__(self, image_list, label_list):
        self.img_path = image_list
        self.label = label_list

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        # Sorted so that, under a fixed torch seed, the sampled tiles do not depend
        # on filesystem listing order.
        image_file_list = sorted(glob(self.img_path[idx] + '/*.jpg'))
        if not image_file_list:
            raise FileNotFoundError(f"no .jpg tiles found in {self.img_path[idx]!r}")

        # torch.randint's `high` is exclusive: using len(...) (not len(...) - 1) lets
        # every tile, including the last, be drawn, and works for one-tile directories.
        image_index = torch.randint(low=0, high=len(image_file_list), size=(image_count,))

        image_tensor = torch.empty((image_count, 3, img_size, img_size))
        for slot, file_index in enumerate(image_index.tolist()):
            # Context manager closes the file even if decoding fails; convert('RGB')
            # guarantees 3 channels for grayscale / palette / RGBA tiles.
            with Image.open(image_file_list[file_index]) as img:
                tile = img.convert('RGB').resize((img_size, img_size))
            # ToTensor yields values in [0, 1]; 1 - x inverts intensities
            # (presumably so a white background maps to 0 — TODO confirm intent).
            image_tensor[slot] = 1 - tf(tile)
        label_tensor = self.label[idx]
        return image_tensor, label_tensor
    
def _collect_split(transition_glob, not_transition_glob):
    """Return ``(item_dirs, labels)`` for one split: transition -> 1, not_transition -> 0.

    Paths are sorted so item order does not depend on filesystem listing order.
    """
    transition_dirs = sorted(glob(transition_glob))
    not_transition_dirs = sorted(glob(not_transition_glob))
    item_dirs = transition_dirs + not_transition_dirs
    labels = [1] * len(transition_dirs) + [0] * len(not_transition_dirs)
    return item_dirs, labels


def _one_hot_labels(labels):
    """One-hot encode integer labels as an ``(N, 2)`` int64 tensor.

    ``num_classes`` is fixed at 2 so a split that contains only one class still
    gets two columns, matching the model's two output logits.
    """
    return F.one_hot(torch.tensor(labels, dtype=torch.int64), num_classes=2)


# Path constants are defined once at the top of this cell; they are not repeated here.
train_image_list, train_label_list = _collect_split(
    train_image_transition_path, train_image_not_transition_path)
test_image_list, test_label_list = _collect_split(
    test_image_transition_path, test_image_not_transition_path)
assert train_image_list, f"no training items matched {train_image_transition_path!r} / {train_image_not_transition_path!r}"
assert test_image_list, f"no test items matched {test_image_transition_path!r} / {test_image_not_transition_path!r}"

train_dataset = CustomDataset(train_image_list, _one_hot_labels(train_label_list))
test_dataset = CustomDataset(test_image_list, _one_hot_labels(test_label_list))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# Validation must visit every item exactly once in a fixed order: no shuffling and
# no dropped tail batch (the model handles any batch size).
validation_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

In [3]:
class FeatureExtractor(nn.Module):
    """Feature extractor: pretrained timm ``inception_resnet_v2`` minus its last child.

    The last top-level child (presumably the classifier) is dropped. The remaining
    network maps a batch of images to one feature vector per image; the recorded
    summary shows ``[40, 1536]`` for 40 input tiles.
    """

    def __init__(self):
        super().__init__()
        backbone = timm.create_model('inception_resnet_v2', pretrained=True)
        # Keep every top-level child in order except the final one.
        *kept_layers, _dropped_head = backbone.children()
        self.feature_ex = nn.Sequential(*kept_layers)

    def forward(self, inputs):
        return self.feature_ex(inputs)
    
class AttentionMILModel(nn.Module):
    """Attention-based multiple-instance classifier over bags of image tiles.

    Every tile is embedded by ``feature_extractor_scale1``. A small MLP scores each
    tile embedding, the scores are softmax-normalised across the tiles of a bag, and
    the attention-weighted sum of embeddings is classified by one linear layer.

    Args:
        num_classes: number of output logits.
        image_feature_dim: length of each tile's flattened feature vector; it must
            equal what the extractor produces (1536 per the recorded summary).
        feature_extractor_scale1: module mapping ``(N, C, H, W)`` tiles to per-tile
            features that flatten to ``image_feature_dim`` values.
    """

    def __init__(self, num_classes, image_feature_dim, feature_extractor_scale1: "FeatureExtractor"):
        super(AttentionMILModel, self).__init__()
        self.num_classes = num_classes
        self.image_feature_dim = image_feature_dim

        # Backbone that embeds individual tiles (its classification head is already removed).
        self.feature_extractor = feature_extractor_scale1

        # Attention scorer: one unnormalised score per tile.
        self.attention = nn.Sequential(
            nn.Linear(image_feature_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

        # Classifier applied to the attention-pooled bag embedding.
        self.classification_layer = nn.Linear(image_feature_dim, num_classes)

    def forward(self, inputs):
        """Classify a batch of bags.

        Args:
            inputs: tensor of shape ``(batch, num_tiles, C, H, W)``.

        Returns:
            Raw (unnormalised) logits of shape ``(batch, num_classes)``.
        """
        batch_size, num_tiles, channels, height, width = inputs.size()

        # Fold tiles into the batch dimension. reshape (unlike view) also accepts
        # non-contiguous inputs.
        tiles = inputs.reshape(-1, channels, height, width)

        # (batch * num_tiles, ...) -> (batch, num_tiles, image_feature_dim)
        features = self.feature_extractor(tiles)
        features = features.reshape(batch_size, num_tiles, -1)

        # (batch, num_tiles, 1), normalised over the tiles of each bag.
        attention_weights = F.softmax(self.attention(features), dim=1)

        # Attention-weighted sum of tile embeddings: (batch, image_feature_dim).
        attended_features = torch.sum(features * attention_weights, dim=1)

        # (batch, num_classes)
        return self.classification_layer(attended_features)
# Assemble the MIL model on the pretrained extractor (1536-d tile features, 2 classes),
# move it to the configured device, set up the metric and optimizer, and show the layout.
Feature_Extractor = FeatureExtractor()
model = AttentionMILModel(
    num_classes=2,
    image_feature_dim=1536,
    feature_extractor_scale1=Feature_Extractor,
).to(device)
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
optimizer = optim.Adam(model.parameters(), lr=2e-4)
summary(model, input_size=(batch_size, image_count, 3, img_size, img_size))

Layer (type:depth-idx)                                       Output Shape              Param #
AttentionMILModel                                            [2, 2]                    --
├─FeatureExtractor: 1-1                                      [40, 1536]                --
│    └─Sequential: 2-1                                       [40, 1536]                --
│    │    └─ConvNormAct: 3-1                                 [40, 32, 255, 255]        928
│    │    └─ConvNormAct: 3-2                                 [40, 32, 253, 253]        9,280
│    │    └─ConvNormAct: 3-3                                 [40, 64, 253, 253]        18,560
│    │    └─MaxPool2d: 3-4                                   [40, 64, 126, 126]        --
│    │    └─ConvNormAct: 3-5                                 [40, 80, 126, 126]        5,280
│    │    └─ConvNormAct: 3-6                                 [40, 192, 124, 124]       138,624
│    │    └─MaxPool2d: 3-7                                   [40, 192, 61, 61] 

In [4]:
# Train for up to EPOCHS epochs. Record per-epoch loss/accuracy for both splits,
# plot periodically, and checkpoint whenever the validation loss improves.
EPOCHS = 1000
PLOT_EVERY = 100    # curves are plotted when epoch % PLOT_EVERY == PLOT_OFFSET
PLOT_OFFSET = 5
BEST_CKPT_PATH = '../../model/image_5x/attention_MIL_callback.pt'
FINAL_CKPT_PATH = '../../model/attention_MIL.pt'

MIN_loss = float('inf')   # best validation loss seen so far
train_loss_list = []
val_loss_list = []
train_acc_list = []
val_acc_list = []


def run_epoch(loader, epoch, training):
    """Run one pass over ``loader``; return ``(mean batch loss, epoch accuracy)``.

    With ``training=True`` the model is in train mode and the optimizer steps after
    every batch. Otherwise the pass runs in eval mode with gradients disabled.

    Raises:
        RuntimeError: if ``loader`` yields no batches.
    """
    model.train(training)
    accuracy.reset()
    prefix = "" if training else "Validation "
    progress = tqdm(loader)
    running_loss = 0.0
    steps = 0
    with torch.set_grad_enabled(training):
        for x, y in progress:
            x = x.to(device).float()
            y = y.to(device).float()    # one-hot targets, (batch, num_classes)
            logits = model(x)
            # cross_entropy applies log_softmax internally, so it needs raw logits;
            # feeding softmax output would apply softmax twice.
            cost = F.cross_entropy(logits, y)
            if training:
                optimizer.zero_grad()
                cost.backward()
                optimizer.step()
            accuracy.update(logits.argmax(dim=1), y.argmax(dim=1))
            running_loss += cost.item()
            steps += 1
            progress.set_description(
                f"{prefix}epoch: {epoch+1}/{EPOCHS} Step: {steps} "
                f"loss : {running_loss/steps:.4f} accuracy: {accuracy.compute().item():.4f}")
    if steps == 0:
        raise RuntimeError("DataLoader yielded no batches; check dataset size vs batch_size/drop_last")
    return running_loss / steps, accuracy.compute().item()


def plot_curves(num_epochs):
    """Plot loss and accuracy curves for the first ``num_epochs`` epochs side by side."""
    epochs_seen = np.arange(num_epochs)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 5))
    ax_loss.plot(epochs_seen, train_loss_list, label='train_loss')
    ax_loss.plot(epochs_seen, val_loss_list, label='validation_loss')
    ax_loss.set(title='loss_graph', xlabel='epoch', ylabel='loss', ylim=(0, 1))
    ax_loss.legend()
    ax_acc.plot(epochs_seen, train_acc_list, label='train_acc')
    ax_acc.plot(epochs_seen, val_acc_list, label='validation_acc')
    ax_acc.set(title='acc_graph', xlabel='epoch', ylabel='accuracy', ylim=(0, 1))
    ax_acc.legend()
    plt.show()
    plt.close(fig)


for epoch in range(EPOCHS):
    train_loss, train_acc = run_epoch(train_dataloader, epoch, training=True)
    val_loss, val_acc = run_epoch(validation_dataloader, epoch, training=False)
    train_loss_list.append(train_loss)
    train_acc_list.append(train_acc)
    val_loss_list.append(val_loss)
    val_acc_list.append(val_acc)

    if epoch % PLOT_EVERY == PLOT_OFFSET:
        plot_curves(epoch + 1)

    if val_loss < MIN_loss:
        os.makedirs(os.path.dirname(BEST_CKPT_PATH), exist_ok=True)
        torch.save(model.state_dict(), BEST_CKPT_PATH)
        MIN_loss = val_loss

os.makedirs(os.path.dirname(FINAL_CKPT_PATH), exist_ok=True)
torch.save(model.state_dict(), FINAL_CKPT_PATH)

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/240 [00:00<?, ?it/s]

KeyboardInterrupt: 