In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import og_mae
from youssef_ganglia_data_loading import HirschImagesDataset
from metrics import mean_iou
from sklearn.metrics import confusion_matrix
import numpy as np
import copy
import pandas as pd
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from einops import rearrange

In [2]:
# Fine-tuned weights. 'backbone' is loaded into OriginalModel.model and
# 'linear' into OriginalModel.linear further down.
# map_location='cpu' lets a GPU-saved checkpoint load on machines without
# CUDA; load_state_dict() later copies the tensors onto the model's device.
checkpoint = torch.load(
    'actual_ganglia_saved_models/ViT_IN1k_ganglia_0.0005_2.pt',
    map_location='cpu',
)
backbone_state_dict = checkpoint['backbone']
linear_state_dict = checkpoint['linear']

def compute_iou(y_pred, y_true, labels=(0, 1), smooth=0.0001):
    """Return the mean IoU over ``labels`` between two label arrays.

    Both inputs are flattened here, so any (matching-size) shape works.
    As with sklearn's ``confusion_matrix(..., labels=...)``, a pixel is
    counted only when BOTH its true and predicted values are in ``labels``;
    pixels with any other value are silently ignored.

    Args:
        y_pred: predicted labels, array-like of any shape.
        y_true: ground-truth labels with the same number of elements.
        labels: classes to score; the IoU is averaged over them.
        smooth: added to numerator and denominator so that a class absent
            from both prediction and ground truth scores 1.0 instead of 0/0.

    Returns:
        Mean IoU as a numpy float64.

    Raises:
        ValueError: if the inputs have different numbers of elements.
    """
    y_pred = np.asarray(y_pred).ravel()
    y_true = np.asarray(y_true).ravel()
    if y_pred.shape != y_true.shape:
        raise ValueError(
            f"y_pred and y_true differ in size: {y_pred.size} vs {y_true.size}"
        )

    labels = np.asarray(labels)
    # Drop pixels whose true or predicted value is outside `labels`.
    keep = np.isin(y_true, labels) & np.isin(y_pred, labels)
    y_true = y_true[keep]
    y_pred = y_pred[keep]

    ious = np.empty(len(labels), dtype=np.float64)
    for k, label in enumerate(labels):
        true_k = y_true == label
        pred_k = y_pred == label
        intersection = np.count_nonzero(true_k & pred_k)
        union = np.count_nonzero(true_k | pred_k)
        # Exact integer counts in float64: a float32 cast loses precision
        # once the union exceeds 2**24 pixels.
        ious[k] = (intersection + smooth) / (union + smooth)
    return ious.mean()

class OriginalModel(nn.Module):
    """MAE ViT-B/16 encoder followed by a per-token linear head.

    ``forward`` maps every token (CLS first, then one per patch) from 768
    encoder features to 512 logits. evaluate_model_miou reads those 512
    values as 2 classes x 16 x 16 pixels per patch.
    """

    def __init__(self, device=None):
        """Build the backbone and head.

        Args:
            device: where to place the parameters. ``None`` (default) picks
                ``'cuda'`` when available, which matches the previous
                hard-coded ``.cuda()``. Otherwise it falls back to ``'cpu'``
                instead of crashing.
        """
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # The attribute names `model` / `linear` matter: the checkpoint
        # halves are loaded into them separately.
        self.model = og_mae.mae_vit_base_patch16_dec512d8b().to(device)
        self.linear = nn.Linear(768, 512).to(device)

    def forward_features(self, img):
        """Encode ``img`` into a CLS token followed by one token per patch.

        Returns:
            Tensor of shape (bsz, 1 + num_patches, 768).
        """
        x = self.model.patch_embed(img)
        # Slot 0 of pos_embed belongs to the CLS token; patches use 1:.
        x = x + self.model.pos_embed[:, 1:, :]

        cls_token = self.model.cls_token + self.model.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Apply the transformer blocks. No masking is done here.
        # NOTE(review): the backbone's final norm layer (if og_mae defines
        # one) is not applied; presumably this matches how the head was
        # trained — confirm.
        for blk in self.model.blocks:
            x = blk(x)  # (bsz, L, 768)

        return x

    def forward(self, x):
        """Return per-token logits of shape (bsz, L, 512)."""
        x = self.forward_features(x)
        x = self.linear(x)  # (bsz, L, 512)
        return x

# Rebuild the network and restore its two halves from the checkpoint
# (strict=True is load_state_dict's default: every key must match).
original_model = OriginalModel()
for _submodule, _weights in (
    (original_model.model, backbone_state_dict),
    (original_model.linear, linear_state_dict),
):
    _submodule.load_state_dict(_weights, strict=True)

def evaluate_model_miou(model, data_loader, compute_iou, device='cuda',
                        thresh=0.5, patch_size=16):
    """Run ``model`` over ``data_loader`` and return two foreground mIoUs.

    Each batch must be ``(img, ganglia_potential, ganglia_certain)``. Images
    are scaled by 1/255 and run under bfloat16 autocast. A pixel is predicted
    foreground when the softmax probability of class 1 exceeds ``thresh``.

    Args:
        model: object exposing ``forward_features(img)`` (tokens, CLS first)
            and ``linear``, which maps each token to 2 * patch_size**2 logits.
        data_loader: iterable of ``(img, ganglia_potential, ganglia_certain)``.
        compute_iou: callable ``(y_pred, y_true) -> float`` on numpy arrays.
        device: device to run on, e.g. ``'cuda'`` or ``'cpu'``.
        thresh: probability threshold for the foreground class.
        patch_size: ViT patch size, used to reassemble per-patch logits
            into a pixel map.

    Returns:
        ``(miou_certain, miou_combined)``. The first is scored against the
        certain mask alone; the second against the union of the certain and
        potential masks.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    # torch.autocast needs the device *type* ('cuda', 'cpu', ...). The
    # CUDA-only torch.cuda.amp.autocast broke the CPU path.
    autocast_device = torch.device(device).type

    all_predictions_test = []
    all_gt_test_certain = []
    all_gt_test_combined = []

    with torch.no_grad():
        for img, ganglia_potential, ganglia_certain in data_loader:
            img = img.to(device, dtype=torch.bfloat16) / 255  # (bsz, 3, H, W)
            ganglia_potential = ganglia_potential.to(device).long().squeeze(dim=1)  # (bsz, H, W)
            ganglia_certain = ganglia_certain.to(device).long().squeeze(dim=1)  # (bsz, H, W)
            # Union of the two masks. Adding them gives 2 where they overlap,
            # and the compute_iou in this file only scores labels 0 and 1, so
            # those pixels would be silently dropped. Assumes binary masks —
            # confirm.
            combined_mask = torch.logical_or(
                ganglia_certain != 0, ganglia_potential != 0
            ).long()

            grid_h = img.shape[-2] // patch_size
            grid_w = img.shape[-1] // patch_size
            with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
                x = model.forward_features(img)
                x = model.linear(x)  # (bsz, L, 2 * patch_size**2)
                # Drop the CLS token, then unpatchify into (bsz, 2, H, W).
                logits = rearrange(
                    x[:, 1:, :],
                    'b (h w) (c i j) -> b c (h i) (w j)',
                    h=grid_h, w=grid_w, c=2, i=patch_size, j=patch_size,
                )
                probability = logits.softmax(dim=1)
                predictions = (probability[:, 1, :, :] > thresh).long()

            all_predictions_test.append(predictions.cpu())
            all_gt_test_certain.append(ganglia_certain.cpu())
            all_gt_test_combined.append(combined_mask.cpu())

    if not all_predictions_test:
        raise ValueError("data_loader yielded no batches")

    predictions_np = torch.cat(all_predictions_test, dim=0).numpy()
    certain_np = torch.cat(all_gt_test_certain, dim=0).numpy()
    combined_np = torch.cat(all_gt_test_combined, dim=0).numpy()

    test_miou_certain = compute_iou(predictions_np, certain_np)
    test_miou_combined = compute_iou(predictions_np, combined_np)
    return test_miou_certain, test_miou_combined


# --- Evaluate the restored model on the held-out ganglia test split ---
batch_size = 64

test_dataset = HirschImagesDataset(
    data_file_path="ganglia_test",
    do_augmentation=False,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,  # deterministic iteration order
    num_workers=8,
)

# Use the GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
original_model.to(device)

test_miou_certain, test_miou_combined = evaluate_model_miou(
    original_model, test_loader, compute_iou, device
)
for _mask_name, _score in (
    ('certain', test_miou_certain),
    ('combined', test_miou_combined),
):
    print(f'Test mIoU ({_mask_name}): {_score:.4f}')


Test mIoU (certain): 0.7701
Test mIoU (combined): 0.7700
