In [None]:
# Install the third-party packages used by this notebook.
# torch / torchvision are pinned to the CUDA 11.3 (+cu113) wheels from the PyTorch wheel index.
!pip install vit-pytorch
!pip install timm
!pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
import os
#os.environ['KMP_DUPLICATE_LIB_OK']='True'
from PIL import Image
import time
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from vit_pytorch import ViT
import albumentations
import albumentations.pytorch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms, models, transforms
from matplotlib import pyplot as plt
import cv2
import numpy as np
import copy
import time
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import confusion_matrix
from collections import defaultdict
import timm
from scipy.special import softmax
from sklearn.model_selection import KFold, StratifiedKFold

In [None]:
# Root folder for all data used by this notebook (paths below are built by string
# concatenation, so every folder constant must end with "/").
BASE = "/Classification/data/"
# Folder holding the per-class image directories ('WFT_<segment>/' and 'NOTWFT_<segment>/').
path_to_images =  BASE + "remove_background/" 
# Segment name: selects the image sub-folders and prefixes the output file names.
SEGMENT = "body"
# NOTE(review): CLASSIFIER is not referenced anywhere else in the visible notebook.
CLASSIFIER = 'vit'
# Destination folder for the saved model state dict.
path_for_models = BASE + "models/"
# Destination folder for the per-fold prediction and loss/accuracy text files.
path_to_save_values = BASE + "remove_background/results/"

In [None]:
# Release cached GPU memory left over from earlier runs in this kernel.
torch.cuda.empty_cache()
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Display the selected device as the cell output.
device

In [None]:
# Show the timm model names matching '*vit*' (informational only; the result is not stored).
timm.list_models('*vit*')

In [None]:
# processing dataset metadata
def process_dataset_meta_info(path, part):
    """Collect the image paths and binary labels for one segment.

    Images are read from two sibling folders under ``path``:
    ``WFT_<part>/`` (label 1) and ``NOTWFT_<part>/`` (label 0). Only file
    names ending in exactly ``.png`` or ``.jpg`` are kept, in sorted order
    within each folder, positives first.

    Args:
        path: Folder prefix; it is concatenated as-is, so it must end with "/".
        part: Segment name used to build the two folder names.

    Returns:
        A numpy array of dicts ``{'file_path': str, 'label': int}``.
    """
    def _list_images(folder):
        # Sorted '.png' / '.jpg' entries of `folder`, each prefixed with the folder path.
        return [folder + name
                for name in sorted(os.listdir(folder))
                if name[-4:] in (".png", ".jpg")]

    positive_paths = _list_images(path + 'WFT_' + part + '/')
    negative_paths = _list_images(path + 'NOTWFT_' + part + '/')

    print('len.......', len(positive_paths) + len(negative_paths))

    records = [{'file_path': file_path, 'label': 1} for file_path in positive_paths]
    records += [{'file_path': file_path, 'label': 0} for file_path in negative_paths]

    # Returned as an array so callers can select folds with index arrays.
    return np.array(records)

In [None]:
# reading thrip dataset
class ThripDataset(Dataset):
    """Dataset that loads thrip images from disk and pairs them with labels.

    Args:
        file_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned index-by-index with ``file_paths``.
        is_train: When true, ``optional_transform`` is applied to each image.
        required_transform: Callable applied to every image as
            ``required_transform(image)``.
        optional_transform: Callable applied only in training mode, invoked as
            ``optional_transform(image=image)['image']``.
    """

    def __init__(self, file_paths, labels, is_train, required_transform=None, optional_transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.required_transform = required_transform
        self.optional_transform = optional_transform
        self.is_train = is_train

    def __len__(self):
        """Return the number of samples."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A dict with keys ``'x'`` (the transformed image), ``'y'`` (the
            label) and ``'id'`` (the three characters preceding the
            4-character file extension).

        Raises:
            OSError: If the image at ``file_paths[idx]`` cannot be read.
        """
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise OSError("Could not read image file: {}".format(file_path))
        # OpenCV loads images as BGR; convert to RGB before the transforms.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Augmentations are only applied to training samples.
        if self.optional_transform and self.is_train:
            image = self.optional_transform(image=image)['image']

        if self.required_transform:
            image = self.required_transform(image)

        # NOTE(review): the id slice assumes a 4-character extension and a
        # 3-character identifier directly before it — confirm against the file naming.
        data_sample = {'x': image, 'y': label, 'id': file_path[-7:-4]}

        return data_sample

In [None]:
def evaluate(y_true, y_pred):
    """Print accuracy, macro precision/recall/F-score and the confusion matrix.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.

    Returns:
        None; the metrics are only printed.
    """
    # Compute every metric before printing anything.
    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f_score, _support = precision_recall_fscore_support(
        y_true, y_pred, average="macro"
    )
    matrix = confusion_matrix(y_true, y_pred)

    scalar_report = (
        ("accuracy:", accuracy),
        ("precision:", precision),
        ("recall:", recall),
        ("f-score:", f_score),
    )
    for caption, value in scalar_report:
        print(caption, value)

    print("confusion matrix:")
    print(matrix, '\n')

In [None]:
def _model_device(model):
    """Return the device of the model's first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def _collect_predictions(model, dataloader, target_device, log_progress=False):
    """Run inference over a dataloader; gather per-sample outputs, labels and ids."""
    outputs, labels, ids = [], [], []
    # No autograd graph is needed for inference; without no_grad every stored
    # output would keep its whole graph alive.
    with torch.no_grad():
        for data in dataloader:
            inputs = data['x'].to(target_device)
            labels.extend(data['y'].to(target_device).long())
            if log_progress:
                print('*****************', len(labels))
            ids.extend(data['id'])
            # Extending with a batch tensor appends one tensor per sample.
            outputs.extend(model(inputs))
    return outputs, labels, ids


def train_model(model, criterion, optimizer, scheduler, dataloaders, class_names, dataset_size, num_epochs=25):
    """Train ``model`` and return it with its final predictions and training curves.

    Each epoch runs a 'train' phase followed by a 'test' phase. The weights
    are snapshotted whenever the epoch training loss reaches a new minimum
    (model selection uses the *training* loss, not the 'test' metrics). After
    the last epoch the best snapshot is loaded back into ``model``, the model
    is switched to eval mode, and predictions are collected without gradient
    tracking for both dataloaders.

    Args:
        model: Network to train; it is modified in place.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        dataloaders: Dict with 'train' and 'test' iterables yielding dicts
            with keys 'x', 'y' and 'id'.
        class_names: Unused; kept for interface compatibility.
        dataset_size: Dict with the number of samples for 'train' and 'test'.
        num_epochs: Number of epochs to run.

    Returns:
        Tuple ``(best_model, outputs, labels, idx, outputs_tr, labels_tr,
        idx_tr, loss_vals, acc_vals)``. ``best_model`` is ``model`` itself
        carrying the best weights; the ``outputs*`` / ``labels*`` lists hold
        one tensor per sample; ``loss_vals`` / ``acc_vals`` map phase name to
        a per-epoch list (losses are floats, accuracies are tensors).
    """
    since = time.time()
    # Inputs are moved to wherever the model lives instead of relying on a
    # notebook-level global.
    target_device = _model_device(model)
    best_model_wts = copy.deepcopy(model.state_dict())
    best_train_loss = float('inf')
    loss_vals = defaultdict(list)
    acc_vals = defaultdict(list)

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and an evaluation phase.
        for phase in ['train', 'test']:
            is_training = phase == 'train'
            model.train(is_training)  # train mode for 'train', eval mode otherwise

            running_loss = 0.0
            running_corrects = 0

            for data in dataloaders[phase]:
                inputs = data['x'].to(target_device)
                labels = data['y'].to(target_device).long()

                optimizer.zero_grad()

                # Track gradients only while training.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # Weight the batch loss by batch size so the epoch figure is a
                # per-sample average.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)

            if is_training:
                scheduler.step()

            epoch_loss = running_loss / dataset_size[phase]
            epoch_acc = running_corrects.double() / dataset_size[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            loss_vals[phase].append(epoch_loss)
            acc_vals[phase].append(epoch_acc)

            # Snapshot only the weights whenever the training loss improves.
            if is_training and epoch_loss < best_train_loss:
                best_train_loss = epoch_loss
                print(".........update ...", epoch)
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    # Restore the best weights and switch to eval mode so dropout and similar
    # layers are inactive during the final prediction passes.
    model.load_state_dict(best_model_wts)
    model.eval()
    best_model = model

    outputs, labels, idx = _collect_predictions(
        best_model, dataloaders["test"], target_device, log_progress=True)
    outputs_tr, labels_tr, idx_tr = _collect_predictions(
        best_model, dataloaders["train"], target_device)

    return best_model, outputs, labels, idx, outputs_tr, labels_tr, idx_tr, loss_vals, acc_vals

In [None]:
def save_values_to_files(file, description, list1, list2 = None, list3 = None):
    """Write a header line followed by one text line per value (or value triple).

    With only ``list1`` given, each element is written on its own line.
    With ``list2`` given, ``list1``/``list2``/``list3`` are zipped and each row
    is written as ``"<d1[0]>, <d1[1]>, <d2>, <d3>"``; every ``list1`` element
    must therefore be indexable with at least two entries.

    Tensors are detached, moved to the CPU and converted to numpy before being
    formatted; any other value is formatted with ``str`` as-is.

    Args:
        file: Path of the file to (over)write.
        description: Header written as the first line.
        list1: Values to save (per-sample outputs in the three-list form).
        list2: Optional second column (labels).
        list3: Third column (ids); required when ``list2`` is given.
    """
    def _as_plain(value):
        # Only tensors need converting; strings, ints and numpy values pass through.
        return value.detach().cpu().numpy() if torch.is_tensor(value) else value

    # The context manager closes the file even if formatting a row raises.
    with open(file, "w+") as f:
        f.write(description + '\n')

        # Identity check: `== None` is evaluated element-wise for array-likes.
        if list2 is None:
            for d1 in list1:
                f.write(str(_as_plain(d1)) + '\n')
        else:
            for d1, d2, d3 in zip(list1, list2, list3):
                d1 = _as_plain(d1)
                d2 = _as_plain(d2)
                f.write(str(d1[0]) + ', ' + str(d1[1]) + ', ' + str(d2) + ', ' + str(d3) + '\n')

In [None]:
def save_acc_loss_to_files(file, description, list1, list2):
    """Write the per-epoch loss/accuracy curves to a text file.

    After the header, one line is written per epoch:
    ``"<epoch>, <train_loss>, <train_acc>, <test_loss>, <test_acc>"``.

    Args:
        file: Path of the file to (over)write.
        description: Header written as the first line.
        list1: Mapping with 'train' and 'test' lists of per-epoch losses.
        list2: Mapping with 'train' and 'test' lists of per-epoch accuracies;
            tensor entries are detached and converted to numpy, other values
            are written as-is.
    """
    def _as_plain(value):
        # Accuracies normally arrive as tensors; plain numbers pass through.
        return value.detach().cpu().numpy() if torch.is_tensor(value) else value

    # The context manager closes the file even if a row fails to format.
    with open(file, "w+") as f:
        f.write(description + '\n')
        per_epoch = zip(list1['train'], list2['train'], list1['test'], list2['test'])
        for epoch, (train_loss, train_acc, test_loss, test_acc) in enumerate(per_epoch):
            row = (epoch, train_loss, _as_plain(train_acc), test_loss, _as_plain(test_acc))
            f.write(', '.join(str(value) for value in row) + '\n')

In [None]:
def read_data(path, filename):
    """Read saved predictions and return labels with positive-class probabilities.

    The first line of the file is treated as a header and skipped. Every other
    non-blank line is parsed as a bracketed pair of scores followed by a
    comma-separated integer, e.g. ``"[ 1.2 -0.3], 1"``: the two scores are the
    first and last numbers before ``]`` and the label is the integer after the
    last comma.

    NOTE(review): this bracketed layout differs from the comma-separated rows
    written by ``save_values_to_files`` in this notebook (whose last field is
    the sample id) — confirm which files this reader is meant for.

    Args:
        path: Folder prefix; concatenated with ``filename`` as-is.
        filename: Name of the file to read.

    Returns:
        Tuple ``(labels, predictions)`` where ``predictions[i]`` is the softmax
        probability of index 1 computed from the two scores of row ``i``.
    """
    labels = []
    predictions = []

    # The context manager closes the file even if a row fails to parse.
    with open(path + filename, 'r') as f:
        f.readline()  # skip the header line

        for line in f:
            print(line)
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline) carry no data.
                continue

            splits = stripped.split(']')
            labels.append(int(splits[-1].split(',')[-1]))

            tokens = splits[0].split()
            if len(tokens) > 2:
                # "[ a b": the bracket is its own token, so the first score is token 1.
                first_score = tokens[1]
            else:
                # "[a b": the bracket is glued to the first score.
                first_score = tokens[0].split('[')[-1]
            preds = [float(first_score), float(tokens[-1])]

            predictions.append(softmax(np.array(preds))[1])

    return labels, predictions

In [None]:
def create_model(model = 'vit_base_patch16_384', optimizer = 'ADAM', lr = 0.01, step_size = 10):
    """Build a pretrained two-class timm model with only its head trainable.

    Args:
        model: timm model name; the network must expose ``blocks`` (with an
            index 11) and ``head`` attributes.
        optimizer: ``'ADAM'`` selects Adam; any other value selects SGD with
            momentum 0.9.
        lr: Learning rate passed to the optimizer.
        step_size: Number of scheduler steps between learning-rate decays
            (each decay multiplies the rate by 0.1).

    Returns:
        Tuple ``(model_ft, optimizer_ft, criterion, exp_lr_scheduler)``.
    """
    model_ft = timm.create_model(model, pretrained=True, num_classes=2).to(device)

    # Freeze the whole network first ...
    model_ft.requires_grad_(False)
    # NOTE(review): block 11 is set to frozen again, which changes nothing after
    # the line above — confirm whether it was meant to be unfrozen.
    model_ft.blocks[11].requires_grad_(False)
    # ... then re-enable gradients for the classification head only.
    model_ft.head.requires_grad_(True)

    criterion = torch.nn.CrossEntropyLoss()

    use_adam = optimizer == 'ADAM'
    optimizer_ft = (
        optim.Adam(model_ft.parameters(), lr=lr)
        if use_adam
        else optim.SGD(model_ft.parameters(), lr, momentum=0.9)
    )

    exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=step_size, gamma=0.1)

    return model_ft, optimizer_ft, criterion, exp_lr_scheduler

In [None]:
def run_the_model(image_path, segment, batch_size, num_epochs, lr, step_size):
    """Run 5-fold stratified cross-validation for one segment and save the results.

    For every fold a fresh model is created, trained, and its outputs written
    to ``<path_to_save_values><segment>_<fold>_<step_size>_{test,train,lossAcc}.txt``.
    The model weights are saved under ``path_for_models``.

    Args:
        image_path: Folder prefix containing the 'WFT_<segment>/' and
            'NOTWFT_<segment>/' image folders.
        segment: Segment name; selects the image folders and prefixes the
            output file names.
        batch_size: Batch size for both dataloaders; ``None`` derives it as
            one fifth of the first fold's training set (at least 1).
        num_epochs: Number of training epochs per fold.
        lr: Learning rate.
        step_size: Scheduler step size, also part of the output file names.
    """
    # Augmentations applied to training images only.
    optional_transforms = albumentations.Compose([
                          albumentations.HorizontalFlip(),
                          albumentations.VerticalFlip(),
                          ])

    # Applied to every image: resize to the 384x384 model input and convert to a tensor.
    required_transforms = transforms.Compose([
                          transforms.ToPILImage(),
                          transforms.Resize((384, 384)),
                          transforms.ToTensor()
                          ])

    # Create the output folders up front: results are only written after a
    # full fold of training, so a missing folder would otherwise waste that work.
    os.makedirs(path_to_save_values, exist_ok=True)
    os.makedirs(path_for_models, exist_ok=True)

    print('Path {}'.format(image_path))
    data = process_dataset_meta_info(image_path, segment)
    print(len(data))

    kf = StratifiedKFold(n_splits=5, shuffle= True, random_state = 8)
    # Stratification needs the label of every sample.
    labels_for_spliter = [d['label'] for d in data]
    print( labels_for_spliter)

    for fold, (train_idx, test_idx) in enumerate(kf.split(np.arange(len(data)), labels_for_spliter)):

        print('Fold {}'.format(fold+1))
        # Indexing with an index array yields copies, so the shuffles below do
        # not disturb `data`.
        train_sampler = data[train_idx]
        test_sampler = data[test_idx]

        # Fixed seed: the sample order of each fold is reproducible.
        np.random.seed(8)
        np.random.shuffle(train_sampler)
        np.random.shuffle(test_sampler)

        train_file_paths = [sample['file_path'] for sample in train_sampler]
        train_labels = [sample['label'] for sample in train_sampler]
        test_file_paths = [sample['file_path'] for sample in test_sampler]
        test_labels = [sample['label'] for sample in test_sampler]

        train_dataset = ThripDataset(file_paths=train_file_paths, labels=train_labels, is_train = True, optional_transform = optional_transforms, required_transform = required_transforms)
        test_dataset = ThripDataset(file_paths=test_file_paths, labels=test_labels, is_train = False, optional_transform = optional_transforms, required_transform = required_transforms)

        if batch_size is None:
            # Guard against rounding down to 0 on very small datasets.
            batch_size = max(1, round(len(train_dataset) / 5))

        # The training order is fixed by the seeded shuffle above; the loader
        # itself does not reshuffle between epochs.
        train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = False, num_workers = 0)
        test_dataloader = DataLoader(test_dataset, batch_size = batch_size)

        dataloaders = {'train': train_dataloader, 'test': test_dataloader}
        dataset_size = {'train': len(train_dataset), 'test': len(test_dataset)}
        class_names = ['NOTWFT', 'WFT']

        model_ft, optimizer_ft, criterion, exp_lr_scheduler = create_model(model = 'vit_base_patch16_384', optimizer = 'ADAM', lr = lr, step_size = step_size)
        model_ft = model_ft.to(device)
        model, outputs, labels, idx, outputs_tr, labels_tr, idx_tr, loss_vals, acc_vals = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, dataloaders, class_names,dataset_size, num_epochs)

        # Name the outputs after the `segment` argument (not the notebook-level
        # SEGMENT constant) so they always match the data that was used.
        root = path_to_save_values + segment + '_' + str(fold) + '_' + str(step_size)
        save_values_to_files(root + '_test.txt', 'test_outputs, labels, id', outputs, labels, idx)
        save_values_to_files(root + '_train.txt', 'train_outputs, labels, id', outputs_tr, labels_tr, idx_tr)
        save_acc_loss_to_files(root + '_lossAcc.txt', 'ephoc, train_loss, train_acc, test_loss, test_acc', loss_vals, acc_vals)
        # NOTE(review): the model path has no fold index, so each fold overwrites
        # the previous one and only the last fold's weights remain — confirm intent.
        torch.save(model.state_dict(), path_for_models + segment + "_vit_base_patch16_384")

In [None]:
# Run the 5-fold experiment for the configured segment: batches of 40, 30 epochs per fold,
# learning rate 0.01 decayed by the scheduler every 10 epochs.
run_the_model(image_path = path_to_images, segment = SEGMENT, batch_size = 40, num_epochs = 30, lr = 0.01, step_size= 10)