In [15]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import og_mae
from youssef_plexus_data_loading import HirschImagesDataset
from metrics import mean_iou
from sklearn.metrics import confusion_matrix
import numpy as np
import copy
import pandas as pd
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from einops import rearrange

In [16]:
checkpoint = torch.load('actual_plexus_saved_models/ViT_IN1k_plexus_0.0002.pt')
backbone_state_dict = checkpoint['backbone']
linear_state_dict = checkpoint['linear']

In [17]:
def compute_iou(y_pred, y_true):
    smooth = 0.0001
    # ytrue, ypred is a flatten vector
    y_pred = y_pred.flatten()
    y_true = y_true.flatten()
    current = confusion_matrix(y_true, y_pred, labels=[0, 1])
    # compute mean iou
    intersection = np.diag(current)
    ground_truth_set = current.sum(axis=1)
    predicted_set = current.sum(axis=0)
    union = ground_truth_set + predicted_set - intersection
    IoU = (intersection+smooth) / (union.astype(np.float32)+smooth)
    return np.mean(IoU)

In [18]:
def compute_mean_inclusion(y_pred, y_true):
    y_pred = y_pred.flatten()
    y_true = y_true.flatten()
    
    current = confusion_matrix(y_true, y_pred, labels=[0, 1])
    print(current)
    
    ground_truth_area = current[1, 1] + current[1, 0]
    print(ground_truth_area)
    
    included_area = current[1, 1]
    print(included_area)
    
    inclusion_percentage = (included_area / ground_truth_area) * 100 if ground_truth_area > 0 else 0
    
    return inclusion_percentage

In [19]:
class OriginalModel(nn.Module):
    def __init__(self):
        super(OriginalModel, self).__init__()
        self.model = og_mae.mae_vit_base_patch16_dec512d8b().cuda()
        self.linear = nn.Linear(768, 512).cuda()

    def forward_features(self, img):
        x = self.model.patch_embed(img)
        x = x + self.model.pos_embed[:, 1:, :]

        cls_token = self.model.cls_token + self.model.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.model.blocks:
            x = blk(x)  # (bsz, L, 768)

        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.linear(x)  # (bsz, L, 512)
        return x

original_model = OriginalModel()
original_model.model.load_state_dict(backbone_state_dict)
original_model.linear.load_state_dict(linear_state_dict)

<All keys matched successfully>

In [20]:
def evaluate_model(model, data_loader, compute_mean_inclusion, device='cuda'):
    model.eval()
    
    thresh = 0.5
    all_predictions_test = []
    all_gt_test = []

    with torch.no_grad():
        for batch in data_loader:
            img, plexus = batch  # load from batch
            img = img.cuda().to(dtype=torch.bfloat16) / 255  # (bsz, 3, H, W)
            plexus = plexus.cuda().long().squeeze(dim=1)  # (bsz, H, W)

            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                x = model.forward_features(img)
                x = model.linear(x)  # (bsz, L, 512)
                logits = rearrange(x[:, 1:, :], 'b (h w) (c i j) -> b c (h i) (w j)', h=14, w=14, c=2, i=16, j=16)  # (bsz, 2, H, W)
                probability = logits.softmax(dim=1)
                predictions = (probability[:, 1, :, :] > thresh).long()

            all_predictions_test.append(predictions.cpu())
            all_gt_test.append(plexus.cpu())

        all_predictions_test = torch.cat(all_predictions_test, dim=0).numpy()
        all_gt_test = torch.cat(all_gt_test, dim=0).numpy()

        test_miou = compute_mean_inclusion(all_predictions_test, all_gt_test)
    
    return test_miou


In [21]:
batch_size = 64

val_dataset = HirschImagesDataset(data_file_path="plexus_val", do_augmentation=False)
val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=8
                       )

test_dataset = HirschImagesDataset(data_file_path="plexus_test", do_augmentation=False)
test_loader = DataLoader(test_dataset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=8
                        )


device = 'cuda' if torch.cuda.is_available() else 'cpu'
original_model.to(device)

val_miou = evaluate_model(original_model, val_loader, compute_mean_inclusion, device)
test_miou = evaluate_model(original_model, test_loader, compute_mean_inclusion, device)
print(f'Val Mean Inclusion: {val_miou:.5f}')
print(f'Test Mean Inclusion: {test_miou:.5f}')


[[350159699    171798]
 [   309420    591083]]
900503
591083
[[350205166    308706]
 [   238973    479155]]
718128
479155
Val Mean Inclusion: 65.63920
Test Mean Inclusion: 66.72278
