In [5]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from dataset import data_loader
from torchvision.models import resnet18
from tqdm import tqdm

root_dir = "../data/TB_data/"


# define data loader
def data_loader(root_dir, image_size=(224, 224), batch_size=30, train_dir='training', test_dir='testing', vald_dir='validation', num_workers=12, drop_last_eval=True):
    """Build train/valid/test ``DataLoader`` objects over ``ImageFolder`` datasets.

    Note: this definition shadows the ``data_loader`` imported from the
    ``dataset`` module at the top of the notebook.

    Args:
        root_dir: Directory that contains the three split sub-directories.
        image_size: Size passed to ``transforms.Resize`` for every split.
        batch_size: Batch size used by all three loaders.
        train_dir: Name of the training sub-directory under ``root_dir``.
        test_dir: Name of the test sub-directory under ``root_dir``.
        vald_dir: Name of the validation sub-directory under ``root_dir``.
        num_workers: Worker processes per loader (formerly hard-coded to 12).
        drop_last_eval: Whether the validation and test loaders drop a final
            incomplete batch. ``True`` keeps the previous hard-coded
            behaviour; pass ``False`` to evaluate on every image.

    Returns:
        A tuple ``(data_loaders, image_dataset)`` of two dicts keyed by
        ``'train'``, ``'valid'`` and ``'test'``.
    """
    dirs = {'train': os.path.join(root_dir, train_dir),
            'valid': os.path.join(root_dir, vald_dir),
            'test': os.path.join(root_dir, test_dir)}

    def eval_transform():
        # Evaluation applies the same deterministic preprocessing as training
        # (grayscale replicated to 3 channels + resize) but no augmentation,
        # so train and eval inputs are preprocessed consistently. For source
        # images that are already grayscale this conversion should be a no-op.
        return transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize(image_size),
            transforms.ToTensor()
        ])

    data_transform = {
        'train': transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.RandomRotation(20),
            transforms.Resize(image_size),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ToTensor()
        ]),
        'valid': eval_transform(),
        'test': eval_transform()
    }

    image_dataset = {x: ImageFolder(dirs[x], transform=data_transform[x])
                     for x in ('train', 'valid', 'test')}

    # Only the training split is shuffled; every training image is kept.
    data_loaders = {'train': DataLoader(image_dataset['train'], batch_size=batch_size,
                                        shuffle=True, num_workers=num_workers)}

    # Evaluation splits keep a fixed order. With drop_last_eval=True a final
    # partial batch is discarded, so metrics then cover fewer images than
    # len(dataset) whenever the split size is not a multiple of batch_size.
    for split in ('test', 'valid'):
        data_loaders[split] = DataLoader(image_dataset[split], batch_size=batch_size,
                                         shuffle=False, num_workers=num_workers,
                                         drop_last=drop_last_eval)

    dataset_size = {x: len(image_dataset[x]) for x in ['train', 'valid', 'test']}

    print([f'number of {i} images is {dataset_size[i]}' for i in (dataset_size)])

    class_idx = image_dataset['test'].class_to_idx
    print(f'Classes with index are: {class_idx}')

    class_names = image_dataset['test'].classes
    print(class_names)
    return data_loaders, image_dataset


# train the model  
def train_on_images(model, dataloader, criterion, optimizer, num_epochs=10, device='cuda', savePth='model.pth'):
    """Train ``model`` on image batches and checkpoint the best validation epoch.

    Args:
        model: Module to train; it is moved to ``device`` in place.
        dataloader: Mapping with ``'train'`` and ``'valid'`` iterables that
            yield ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epochs: Number of passes over the training loader.
        device: Device the model and every batch are moved to.
        savePth: Where the whole model object is written with ``torch.save``
            each time the validation loss improves; ``None`` disables saving.

    Returns:
        The same ``model`` object in its state after the last epoch (which is
        not necessarily the checkpointed best epoch).
    """
    model.to(device)
    best_loss = float('inf')
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print("-" * 10)

        # Enter training mode explicitly: a model handed over in eval mode
        # would otherwise run its whole first epoch with eval-mode layers.
        model.train()
        running_loss = 0.0
        train_seen = 0
        for inputs, labels in tqdm(dataloader['train']):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            train_seen += inputs.size(0)
        # Normalise by the samples actually processed; this stays correct
        # when a loader drops its last batch, and max(..., 1) avoids a
        # division by zero on an empty loader.
        epoch_loss = running_loss / max(train_seen, 1)
        print(f"Epoch {epoch+1} - Train Loss: {epoch_loss:.4f}")

        # validate
        val_loss = 0.0
        val_seen = 0
        model.eval()
        with torch.no_grad():
            for inputs, labels in dataloader['valid']:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * inputs.size(0)
                val_seen += inputs.size(0)
        val_loss /= max(val_seen, 1)
        print(f"Epoch {epoch+1} - Validation Loss: {val_loss:.4f}")

        # save best model
        if val_loss < best_loss:
            best_loss = val_loss
            if savePth is not None:
                torch.save(model, savePth)

        # Hand the model back in training mode, as before.
        model.train()

    print("Training complete!")
    if savePth is not None:
        print(f"Best model saved at {savePth}")
    return model

# def train_on_patches(model, dataloader, criterion, optimizer, vit_model, num_epochs=10, device='cuda', savePth='model.pth'):
#     model.to(device)
#     best_loss = float('inf')
#     for epoch in range(num_epochs):
#         print(f"Epoch {epoch + 1}/{num_epochs}")
#         print("-" * 10)
#         running_loss = 0.0
#         for inputs, labels in tqdm(dataloader['train']):
#             inputs = inputs.to(device)
#             labels = labels.to(device)
#             optimizer.zero_grad()
#             x = vit_model.patch_embed(inputs)
#             x = vit_model.pos_drop(x)
#             for block in vit_model.blocks:
#                 x = block(x)
#             outputs = model(x)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#             running_loss += loss.item() * inputs.size(0)
#         epoch_loss = running_loss / len(dataloader['train'].dataset)
#         print(f"Epoch {epoch+1} - Train Loss: {epoch_loss:.4f}")
        
#         # validate
#         val_loss = 0.0
#         model.eval()
#         with torch.no_grad():
#             for inputs, labels in dataloader['valid']:
#                 inputs = inputs.to(device)
#                 labels = labels.to(device)
#                 x = vit_model.patch_embed(inputs)
#                 x = vit_model.pos_drop(x)
#                 for block in vit_model.blocks:
#                     x = block(x)
#                 outputs = model(x)
#                 loss = criterion(outputs, labels)
#                 val_loss += loss.item() * inputs.size(0)
#         val_loss /= len(dataloader['valid'].dataset)
#         print(f"Epoch {epoch+1} - Validation Loss: {val_loss:.4f}")

#         # Release unused memory
#         torch.cuda.empty_cache()
        
#         # save best model
#         if val_loss < best_loss:
#             best_loss = val_loss
#             if savePth != None:
#                 torch.save(model, savePth)
        
#         model.train()

#     print("Training complete!")
#     if savePth != None:
#         print(f"Best model saved at {savePth}")
#     return model

def _vit_block_features(vit_model, inputs, clf):
    """Return the output of the first ``clf`` blocks of ``vit_model`` for ``inputs``.

    The backbone is only used as a fixed feature extractor here, so it runs
    under ``torch.no_grad()``: no autograd graph is kept and no gradients
    accumulate on backbone parameters.
    """
    with torch.no_grad():
        x = vit_model.patch_embed(inputs)
        x = vit_model.pos_drop(x)
        for block in range(clf):
            x = vit_model.blocks[block](x)
    return x


def train_on_patches(model, dataloader, criterion, optimizer, vit_model, clf, num_epochs=10, device='cuda', savePth='model.pth'):
    """Train ``model`` on features taken from the first ``clf`` blocks of a ViT.

    Each batch goes through ``vit_model.patch_embed``, ``vit_model.pos_drop``
    and ``vit_model.blocks[0:clf]``; the result is the input of ``model``.
    The ViT is treated as frozen: it is evaluated without gradient tracking
    and in eval mode, so only ``model`` can be updated by ``optimizer``.

    Args:
        model: Head to train; it is moved to ``device`` in place.
        dataloader: Mapping with ``'train'`` and ``'valid'`` iterables that
            yield ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        vit_model: Backbone exposing ``patch_embed``, ``pos_drop`` and an
            indexable ``blocks``; it is moved to ``device`` in place.
        clf: Number of leading ViT blocks to apply before the head.
        num_epochs: Number of passes over the training loader.
        device: Device the models and every batch are moved to.
        savePth: Where the whole head object is written with ``torch.save``
            each time the validation loss improves; ``None`` disables saving.

    Returns:
        The same ``model`` object in its state after the last epoch (which is
        not necessarily the checkpointed best epoch).

    Raises:
        IndexError: If ``clf`` exceeds the number of blocks in ``vit_model``.
    """
    model.to(device)
    vit_model.to(device)

    # Run the backbone deterministically (no dropout) while it serves as a
    # feature extractor; its previous mode is restored on exit.
    vit_was_training = vit_model.training
    vit_model.eval()

    best_loss = float('inf')
    try:
        for epoch in range(num_epochs):
            print(f"Epoch {epoch + 1}/{num_epochs}")
            print("-" * 10)

            # Enter training mode explicitly so the first epoch never runs
            # with eval-mode layers when the head arrives in eval mode.
            model.train()
            running_loss = 0.0
            train_seen = 0
            for inputs, labels in tqdm(dataloader['train']):
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                x = _vit_block_features(vit_model, inputs, clf)
                outputs = model(x)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                train_seen += inputs.size(0)
            # Normalise by the samples actually processed (robust to loaders
            # that drop their last batch and to empty loaders).
            epoch_loss = running_loss / max(train_seen, 1)
            print(f"Epoch {epoch+1} - Train Loss: {epoch_loss:.4f}")

            # validate
            val_loss = 0.0
            val_seen = 0
            model.eval()
            with torch.no_grad():
                for inputs, labels in dataloader['valid']:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    x = _vit_block_features(vit_model, inputs, clf)
                    outputs = model(x)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item() * inputs.size(0)
                    val_seen += inputs.size(0)
            val_loss /= max(val_seen, 1)
            print(f"Epoch {epoch+1} - Validation Loss: {val_loss:.4f}")

            # Release unused memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # save best model
            if val_loss < best_loss:
                best_loss = val_loss
                if savePth is not None:
                    torch.save(model, savePth)

            # Hand the head back in training mode, as before.
            model.train()
    finally:
        vit_model.train(vit_was_training)

    print("Training complete!")
    if savePth is not None:
        print(f"Best model saved at {savePth}")
    return model



# test the model
def test_on_images(model, dataloader, device='cuda'):
    print(type(model))
    model.eval()
    print("Evaluated")
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in tqdm(dataloader['test']):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))


def test_on_patches(model, dataloader, vit_model, device='cuda'):
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in tqdm(dataloader['test']):
            images = images.to(device)
            labels = labels.to(device)
            x = vit_model.patch_embed(images)
            x = vit_model.pos_drop(x)
            for block in vit_model.blocks:
                x = block(x)
            outputs = model(x)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))


def test_vit(model, dataloader_test):
    """
    This function used to test ViT. 

    Args: 
        model: ViT model
        dataaloader_test: loader for test images 
    return: 
        Avg test accuracy of ViT
    
    """
    test_acc = 0.0
    for images, labels in tqdm(dataloader_test): 
        images = images.cuda()
        labels= labels.cuda()
        with torch.no_grad(): 
            model.eval()
            output = model(images)
            prediction = torch.argmax(output, dim=-1)
            acc = sum(prediction == labels).float().item()/len(labels)
            test_acc += acc
    print(f'Testing accuracy = {(test_acc/len(dataloader_test)):.4f}')

    return round(test_acc/len(dataloader_test),2)


def test_all_classifiers(classifiers_list, dataloader_test, mlp_root_dir, vit_model):
    for clf in range(1, len(classifiers_list) +1):
        acc_avg = 0.0
        print(classifiers_list[clf-1])
        clf_in = torch.load(os.path.join(mlp_root_dir, classifiers_list[clf-1])).cuda()
        clf_in.eval()
        # print(clf_in)
        print(f'Classifier of index {clf-1} has been loaded')

        for images, labels in tqdm(dataloader_test): 
            images = images.cuda()
            labels= labels.cuda()
            # print(images.shape) #torch.Size([30, 3, 224, 224])
            x = vit_model.patch_embed(images)
            # print(x.shape)  #torch.Size([30, 196, 768])
            x = vit_model.pos_drop(x)
            # print(x.shape)  #torch.Size([30, 196, 768])
            for block in range(clf):
                x = vit_model.blocks[block](x)
            # x = x.reshape(30, 3, 224, 224)
            with torch.no_grad():
                # print(x.shape) #torch.Size([30, 196, 768])
                output = clf_in(x)
            predictions = torch.argmax(output, dim=-1)
            acc = torch.sum(predictions == labels).item()/len(labels)
            acc_avg += acc
        print(f'Accuracy of block {clf-1} = {(acc_avg/len(dataloader_test)):.3f}')
    pass

In [2]:
def countParams(model):
    """Print how many parameters ``model`` has in total and how many are trainable."""
    total_params = 0
    trainable_params = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total_params += size
        # Only parameters with requires_grad set count as trainable.
        if tensor.requires_grad:
            trainable_params += size
    print("Total Params: ", total_params)
    print("Trainable Params: ", trainable_params)

# Training

In [3]:
# Train MLP head(s) on ViT block features with train_on_patches.
import mlp as mlp

criterion = nn.CrossEntropyLoss()

# Load data
loader_, dataset_ = data_loader(root_dir=root_dir)

# Whole pickled ViT model; it is used here only to produce features for the head.
vit_model = torch.load('../models/vit_base_patch16_224_in21k_test-accuracy_0.96_chest.pth')

# Train the MLP head(s) and save each one. range(1) currently trains a single
# head for one epoch; clf=5 feeds it the output of the first five ViT blocks.
# NOTE(review): the file is saved as mlp_block_0.pth although it was trained
# with clf=5 — confirm the intended block index before evaluating it.
for i in range(1):
    print("\n\nmlp ", i, ":")
    mlp_in = mlp.Classifier()
    optimizer_mlp = optim.Adam(mlp_in.parameters(), lr=0.001)
    model_mlp = train_on_patches(mlp_in, loader_, criterion, optimizer_mlp, vit_model=vit_model, clf=5, num_epochs=1, savePth="../ReVIT/models/MyModels2/mlp_block_"+ str(i) +".pth", device='cuda')

['number of train images is 5670', 'number of valid images is 630', 'number of test images is 700']
Classes with index are: {'Normal': 0, 'Tuberculosis': 1}
['Normal', 'Tuberculosis']


mlp  0 :
Epoch 1/1
----------


100%|██████████| 189/189 [02:27<00:00,  1.28it/s]

Epoch 1 - Train Loss: 24.6238





Epoch 1 - Validation Loss: 0.6844
Training complete!
Best model saved at ../ReVIT/models/MyModels2/mlp_block_0.pth


In [None]:
# Train randomly sized MLPs directly on images with train_on_images.
import random_mlp as random_mlp
import random

criterion = nn.CrossEntropyLoss()

# Load data
loader_, dataset_ = data_loader(root_dir=root_dir)

# Train 5 MLP models, each with a random depth of 4 to 10 layers, and save them.
# NOTE(review): no random seed is set, so the sampled depths differ between runs.
for i in range(5):
    print("\n\nmlp ", i, ":")
    r_mlp = random_mlp.Classifier(num_layers=random.randint(4,10))
    optimizer_mlp = optim.Adam(r_mlp.parameters(), lr=0.001)
    model_mlp = train_on_images(r_mlp, loader_, criterion, optimizer_mlp, num_epochs=5, savePth="../ReVIT/models/R_Models/random_mlp"+ str(i) +".pth", device='cuda')

In [13]:
model_dir = '../ReVIT/models/R_Models/'
# Print the parameter counts of every checkpoint in the directory
# (in the order os.listdir returns them).
for checkpoint_name in os.listdir(model_dir):
    print(checkpoint_name)
    checkpoint = torch.load(model_dir + checkpoint_name)
    countParams(checkpoint)

random_mlp1.pth
Total Params:  627753026
Trainable Params:  627753026
random_mlp0.pth
Total Params:  627712258
Trainable Params:  627712258
random_mlp.pth
Total Params:  316680194
Trainable Params:  316680194
random_mlp4.pth
Total Params:  627753026
Trainable Params:  627753026
random_mlp3.pth
Total Params:  627581442
Trainable Params:  627581442
random_mlp2.pth
Total Params:  627581442
Trainable Params:  627581442


In [5]:
# Parameter counts of one random-MLP checkpoint.
random_mlp2 = torch.load('../ReVIT/models/R_Models/random_mlp2.pth')
countParams(random_mlp2)

Total Params:  627581442
Trainable Params:  627581442


# Testing

In [4]:
# Evaluate every random-MLP checkpoint on the test images.
# Relies on `loader_` created by an earlier data_loader(...) cell.
model_dir = '../ReVIT/models/R_Models/'
checkpoint_names = sorted(os.listdir(model_dir))
for index, checkpoint_name in enumerate(checkpoint_names):
    print("\n------------->", checkpoint_name, index)
    model = torch.load(model_dir + checkpoint_name)
    countParams(model)
    test_on_images(model=model, dataloader=loader_)


-------------> random_mlp0.pth 0
Total Params:  627712258
Trainable Params:  627712258
<class 'random_mlp.Classifier'>
Evaluated


100%|██████████| 23/23 [00:01<00:00, 12.46it/s]


Accuracy of the network on the test images: 87 %

-------------> random_mlp1.pth 1
Total Params:  627753026
Trainable Params:  627753026
<class 'random_mlp.Classifier'>
Evaluated


100%|██████████| 23/23 [00:01<00:00, 12.33it/s]


Accuracy of the network on the test images: 86 %

-------------> random_mlp2.pth 2
Total Params:  627581442
Trainable Params:  627581442
<class 'random_mlp.Classifier'>
Evaluated


100%|██████████| 23/23 [00:02<00:00, 11.39it/s]


Accuracy of the network on the test images: 82 %

-------------> random_mlp3.pth 3
Total Params:  627581442
Trainable Params:  627581442
<class 'random_mlp.Classifier'>
Evaluated


100%|██████████| 23/23 [00:01<00:00, 11.83it/s]


Accuracy of the network on the test images: 81 %

-------------> random_mlp4.pth 4
Total Params:  627753026
Trainable Params:  627753026
<class 'random_mlp.Classifier'>
Evaluated


100%|██████████| 23/23 [00:01<00:00, 12.11it/s]

Accuracy of the network on the test images: 84 %





In [12]:
# Evaluate the trained MLP heads on ViT features with test_on_patches.
# NOTE(review): test_on_patches runs the images through every ViT block,
# whereas the training cell built the head's inputs with clf=5 (first five
# blocks only). The 49 % accuracy below likely reflects that mismatch — confirm.
model_dir = '../ReVIT/models/MyModels2/'
loader_, dataset_ = data_loader(root_dir=root_dir)
vit_model = torch.load('../models/vit_base_patch16_224_in21k_test-accuracy_0.96_chest.pth')
# The backbone is only used for inference here.
vit_model.eval()
countParams(vit_model)
# Checkpoints are visited in sorted file-name order.
for index, i in enumerate(sorted(os.listdir(model_dir))):
    print("\n------------->", i, index)
    model = torch.load(model_dir+i)
    countParams(model)
    test_on_patches(model=model, dataloader=loader_, vit_model=vit_model)

['number of train images is 5670', 'number of valid images is 630', 'number of test images is 700']
Classes with index are: {'Normal': 0, 'Tuberculosis': 1}
['Normal', 'Tuberculosis']
Total Params:  85800194
Trainable Params:  85800194

-------------> mlp_block_0.pth 0
Total Params:  625219970
Trainable Params:  625219970


100%|██████████| 23/23 [00:06<00:00,  3.82it/s]

Accuracy of the network on the test images: 49 %





#### Testing the mlp modules through the first 5 blocks of ViT

In [8]:
# Parameter counts of the fine-tuned ViT checkpoint, moved to the GPU first.
vit_checkpoint = torch.load('../models/vit_base_patch16_224_in21k_test-accuracy_0.96_chest.pth')
countParams(vit_checkpoint.to('cuda'))

Total Params:  85800194
Trainable Params:  85800194


In [8]:
# Rebuild the loaders, reload the fine-tuned ViT and measure its test accuracy.
root_dir = "../data/TB_data/"
loader_, dataset_ = data_loader(root_dir)
vit_model = torch.load('../models/vit_base_patch16_224_in21k_test-accuracy_0.96_chest.pth')
countParams(model=vit_model)
test_vit(model=vit_model, dataloader_test=loader_['test'])

['number of train images is 5670', 'number of valid images is 630', 'number of test images is 700']
Classes with index are: {'Normal': 0, 'Tuberculosis': 1}
['Normal', 'Tuberculosis']
Total Params:  85800194
Trainable Params:  85800194


100%|██████████| 23/23 [00:05<00:00,  3.84it/s]

Testing accuracy = 0.9638





0.96

#### Testing pretrained MLP modules (their input is the patch features output by the ViT blocks)

In [29]:
import os
# Evaluate the pretrained per-block classifiers; sorted() puts the files in
# block order, which test_all_classifiers relies on (position k -> k+1 blocks).
classifiers_list = sorted(os.listdir('../models/MLP_new_chest'))
# print(classifiers_list)
# The function is defined as test_all_classifiers; the former call to
# test_classifiers raised NameError on a fresh kernel.
test_all_classifiers(classifiers_list=classifiers_list, dataloader_test=loader_['test'], mlp_root_dir='../models/MLP_new_chest', vit_model=vit_model)

block_0_classifier_0.94test_0.98train.pth
Classifier of index 0 has been loaded


100%|██████████| 23/23 [00:01<00:00, 11.60it/s]


Accuracy of block 0 = 0.945
block_1_classifier_0.93test_0.99train.pth
Classifier of index 1 has been loaded


100%|██████████| 23/23 [00:02<00:00, 10.64it/s]


Accuracy of block 1 = 0.933
block_2_classifier_0.94test_0.99train.pth
Classifier of index 2 has been loaded


100%|██████████| 23/23 [00:02<00:00,  9.86it/s]


Accuracy of block 2 = 0.941
block_3_classifier_0.93test_0.99train.pth
Classifier of index 3 has been loaded


100%|██████████| 23/23 [00:02<00:00,  9.12it/s]


Accuracy of block 3 = 0.933
block_4_classifier_0.92test_1.00train.pth
Classifier of index 4 has been loaded


100%|██████████| 23/23 [00:02<00:00,  8.78it/s]

Accuracy of block 4 = 0.923





#### Testing the newly trained MLP modules (their input is the patch features output by the ViT blocks)

In [10]:
import os
# Evaluate the heads trained in this notebook, one per ViT depth.
classifiers_list = sorted(os.listdir('../ReVIT/models/MyModels2'))
# print(classifiers_list)
# The function is defined as test_all_classifiers; the former call to
# test_classifiers raised NameError on a fresh kernel.
# NOTE(review): the file at position 0 is evaluated on the output of one ViT
# block, but mlp_block_0.pth was trained with clf=5 in the training cell; that
# probably explains the ~0.5 accuracy reported below — confirm.
test_all_classifiers(classifiers_list=classifiers_list, dataloader_test=loader_['test'], mlp_root_dir='../ReVIT/models/MyModels2', vit_model=vit_model)

mlp_block_0.pth
Classifier of index 0 has been loaded


100%|██████████| 23/23 [00:02<00:00, 10.45it/s]

Accuracy of block 0 = 0.507



