In [5]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms, utils
from einops import rearrange
import os
from torchvision.transforms import (
    RandomHorizontalFlip,
    RandomRotation,
    RandomVerticalFlip,
    RandomApply,
    InterpolationMode,
    RandomCrop,
    RandomResizedCrop,
    CenterCrop
)
import math
import csv
#from histo_vit import vit_small
import random
from torchvision.transforms.functional import hflip
from torchvision.transforms.functional import vflip
#import segmenter
import og_mae
from youssef_plexus_data_loading import HirschImagesDataset
from metrics import mean_iou
from sklearn.metrics import confusion_matrix
import numpy as np
import copy
import pandas as pd

In [2]:
# Load a previously saved checkpoint from disk. Despite its name, `file_path` holds the
# loaded object itself rather than a path; the cell output further down shows a dict
# with a 'backbone' entry.
# NOTE(review): torch.load can unpickle arbitrary objects, so only load checkpoint
# files that come from a trusted source.
file_path = torch.load("actual_plexus_saved_models/ViT_IN1k_plexus_0.001.pt")

In [3]:
# Echo the loaded checkpoint as the cell output. This prints the full tensor contents,
# which makes the output very long.
file_path

{'backbone': OrderedDict([('cls_token',
               tensor([[[ 1.4092e-02, -3.1947e-03,  1.4926e-03, -1.3980e-02, -1.1262e-03,
                         -1.8892e-01, -2.8300e-03, -5.8989e-03, -2.4389e-02,  1.9444e-03,
                         -4.1007e-03,  3.3894e-03,  3.8824e-03, -1.2017e-02, -9.1470e-03,
                         -2.3664e-01,  3.8039e-02, -2.8332e-01,  8.3252e-02, -4.2101e-02,
                         -4.9465e-02,  4.5167e-02,  9.4976e-02, -2.5377e-02, -1.2333e-03,
                         -7.9862e-03, -1.0309e-02, -1.0979e-02, -1.4768e-02, -1.0680e-02,
                         -7.4596e-03, -6.1222e-03,  1.3337e-02, -3.3151e-02, -1.5566e-02,
                         -1.9576e-03,  3.8417e-03, -2.2146e-02, -8.2715e-03, -1.0323e-02,
                         -5.9301e-03,  1.4620e-02,  8.0951e-04, -2.2796e-02, -1.7509e-02,
                         -4.0207e-03, -1.1738e-03,  1.0374e-01,  3.4161e-03, -2.0940e-03,
                         -2.8637e-02, -5.2454e-03,  1.0488e-

In [9]:
def augment_image_with_map(_img, _map):
    """Apply the same random geometric augmentation to an image and its paired map.

    Pipeline:
      1. Rotate both tensors by one random integer angle drawn from [0, 90) degrees.
      2. Center-crop to a square of side ``512 / (cos(angle) + sin(angle))``, i.e. the
         axis-aligned square that fits inside a rotated 512x512 square.
      3. Apply random horizontal/vertical flips and a random resized crop to 224x224.

    The global torch RNG state is captured before each transform runs on the image and
    restored before it runs on the map, so both tensors receive identical random
    parameters. The global torch RNG is still advanced by this call.

    The image is resampled bilinearly in step 3, while the map is resampled with
    nearest-neighbour interpolation so that its values are never blended.
    NOTE(review): this assumes `_map` holds discrete labels - confirm against callers.

    Args:
        _img: image tensor indexed as [channel, y, x]. The crop arithmetic hard-codes an
            outer side of 512, so the input is presumably (C, 512, 512) - TODO confirm.
        _map: map tensor indexed the same way, with the same spatial size as `_img`.

    Returns:
        Tuple ``(_img, _map)`` after augmentation, each with spatial size 224x224.
    """
    side_outer = 512
    angle = torch.randint(low=0, high=90, size=(1,)).item()

    # Passing (angle, angle) pins RandomRotation to exactly this angle.
    aug1 = torch.nn.Sequential(RandomRotation((angle, angle)))

    # Side of the axis-aligned square inscribed in the rotated outer square.
    side_inner = side_outer / (math.cos(math.radians(angle)) + math.sin(math.radians(angle)))

    # Replay the same RNG state so image and map get the same rotation draw.
    state = torch.get_rng_state()
    _img = aug1(_img)

    torch.set_rng_state(state)
    _map = aug1(_map)

    center_x = side_outer // 2
    center_y = side_outer // 2

    half_width = side_inner // 2
    half_height = side_inner // 2

    start_x = round(center_x - half_width)
    end_x = round(center_x + half_width)
    start_y = round(center_y - half_height)
    end_y = round(center_y + half_height)

    # Drop the corners that the rotation filled with padding.
    _img = _img[:, start_y:end_y, start_x:end_x]
    _map = _map[:, start_y:end_y, start_x:end_x]

    def _flip_and_resized_crop(interpolation):
        # Both pipelines draw the same random numbers; only the resampling filter differs.
        return torch.nn.Sequential(
            RandomHorizontalFlip(p=0.5),
            RandomVerticalFlip(p=0.5),
            RandomResizedCrop(size=(224, 224), scale=(0.5, 2.0), interpolation=interpolation),
        )

    aug2_img = _flip_and_resized_crop(InterpolationMode.BILINEAR)
    # Nearest-neighbour keeps map values intact instead of averaging neighbouring pixels.
    aug2_map = _flip_and_resized_crop(InterpolationMode.NEAREST)

    state = torch.get_rng_state()
    _img = aug2_img(_img)

    torch.set_rng_state(state)
    _map = aug2_map(_map)

    return _img, _map


In [10]:
def adjust_learning_rate(epoch, sched_config):
    """Compute the learning rate for `epoch`: linear warmup, then half-cycle cosine decay.

    Args:
        epoch: current epoch (may be fractional).
        sched_config: mapping with the keys 'lr' (peak rate), 'min_lr' (rate reached at
            the end of the decay), 'warmup_epochs' and 'epochs' (total schedule length).

    Returns:
        The learning rate for `epoch`. The optimizer is not modified here.
    """
    warmup = sched_config['warmup_epochs']

    # Warmup: ramp linearly from 0 up to the peak learning rate.
    if epoch < warmup:
        return sched_config['lr'] * epoch / warmup

    # Decay: follow half a cosine period from the peak down to the floor.
    floor_lr = sched_config['min_lr']
    peak_lr = sched_config['lr']
    phase = math.pi * (epoch - warmup) / (sched_config['epochs'] - warmup)
    return floor_lr + (peak_lr - floor_lr) * 0.5 * (1. + math.cos(phase))


In [11]:
def get_lr(optimizer):
    """Return the 'lr' of the optimizer's first parameter group.

    Returns None when `optimizer.param_groups` is empty.
    """
    for first_group in optimizer.param_groups:
        break
    else:
        # No parameter groups at all: nothing to report.
        return None
    return first_group['lr']

In [12]:
def compute_iou(y_pred, y_true):
    """Mean intersection-over-union over the two classes 0 and 1.

    Args:
        y_pred: array of predicted labels; flattened before use.
        y_true: array of ground-truth labels; flattened before use.

    Returns:
        The mean of the per-class IoU values, each smoothed by a small constant so an
        absent class does not cause a division by zero.
    """
    eps = 0.0001

    # Rows index the ground-truth class, columns the predicted class.
    cm = confusion_matrix(y_true.flatten(), y_pred.flatten(), labels=[0, 1])

    # The diagonal counts the pixels where prediction and ground truth agree, per class.
    hits = np.diag(cm)
    per_class_union = cm.sum(axis=1) + cm.sum(axis=0) - hits

    numerator = hits + eps
    denominator = per_class_union.astype(np.float32) + eps
    return np.mean(numerator / denominator)

In [15]:
# --- Run configuration -------------------------------------------------------------
# Candidate base learning rates. No sweep loop is active in this cell, so only
# `base_lr` below is actually used.
# learning_rates = [2e-3, 5e-3, 8e-3]
learning_rates = [1e-4, 2e-4, 5e-4, 1e-3]
best_lr = None
best_model_state = None
best_linear_layer = None
use_mixup = False
lambda_values = [0.2, 0.5, 0.8]

columns = ['Learning Rate', 'Epoch', 'Val mIoU', 'Test mIoU']
model_info_df = pd.DataFrame(columns=columns)

# Bind the base learning rate explicitly. It was previously only available as leftover
# kernel state, so this cell raised NameError after "Restart Kernel -> Run All".
base_lr = 1e-4
print(f'Learning Rate: {base_lr}')

batch_size = 64

# --- Data --------------------------------------------------------------------------
# Only the training split is built with augmentation enabled and shuffled.
train_dataset = HirschImagesDataset(data_file_path="plexus_train", do_augmentation=True)
train_loader = DataLoader(train_dataset,
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=8
                         )

val_dataset = HirschImagesDataset(data_file_path="plexus_val", do_augmentation=False)
val_loader = DataLoader(val_dataset,
                          batch_size=batch_size,
                          shuffle=False,
                          num_workers=8
                       )

test_dataset = HirschImagesDataset(data_file_path="plexus_test", do_augmentation=False)
test_loader = DataLoader(test_dataset,
                      batch_size=batch_size,
                      shuffle=False,
                      num_workers=8
                        )

best_val_miou = 0
# Scale the base learning rate linearly with the batch size, relative to 256.
learning_rate = base_lr * batch_size / 256 # added

# --- Model -------------------------------------------------------------------------
# Backbone weights come from the 'model' entry of the MAE checkpoint file.
model = og_mae.mae_vit_base_patch16_dec512d8b().cuda()
model.load_state_dict(torch.load('mae_visualize_vit_base.pth')['model'])
# Per-token head: 768 features -> 512 values, later reshaped into 2 classes x 16 x 16 pixels.
# NOTE(review): `linear` is freshly initialized and no fine-tuned checkpoint is loaded
# in this cell, so any evaluation run without training measures an untrained head -
# confirm this is intended.
linear = nn.Linear(768, 512).cuda()

# optimizer
# Parameter group 0 is the backbone and group 1 is the linear head.
backbone_params = model.parameters()
linear_params = linear.parameters()
# head_params = seg_head.parameters()
opt = torch.optim.AdamW([{'params': backbone_params}, {'params': linear_params}], lr=learning_rate)
loss_function = torch.nn.CrossEntropyLoss()

# Prep LR stepping
# Each dict carries the keys that adjust_learning_rate reads; `multiplier` scales the
# head's peak rate relative to the backbone's.
epochs = 50
multiplier = 1
backbone_config = {'lr': learning_rate,
                   'warmup_epochs': 5,
                   'min_lr': 0,
                   'epochs': epochs}

head_config = {'lr': multiplier * learning_rate,
               'warmup_epochs': 5,
               'min_lr': 0,
               'epochs': epochs}
# Counter that is only referenced by the commented-out training loop.
num_down = 0
#     for epoch in range(epochs):
#         if num_down >= 20:
#             break

#         opt.param_groups[0]['lr'] = adjust_learning_rate(epoch, backbone_config)
#         opt.param_groups[1]['lr'] = adjust_learning_rate(epoch, head_config)

#         current_lr_backbone = opt.param_groups[0]['lr']  # confirm
#         current_lr_head = opt.param_groups[1]['lr']  # confirm

#         train_losses = []

#         model = model.train()
#         # seg_head = seg_head.train()
#         linear = linear.train()
#         for batch in train_loader:
#             img, plexus = batch  # load from batch

#             # Q: I shouldn't augment again right?

#             img = img.cuda().to(dtype=torch.bfloat16) / 255  # (bsz, 3, H, W)
#             plexus = plexus.cuda().long().squeeze(dim=1)  # (bsz, H, W)

#             # Mix the inputs and the labels here
#             # 1st Step: flip the order of the images
#             if use_mixup:
#                 img_flipped = img.flip(0)
#                 img = (1 - lam) * img_flipped + lam * img

#             with torch.cuda.amp.autocast(dtype=torch.bfloat16):
#                 x = model.patch_embed(img)
#                 x = x + model.pos_embed[:, 1:, :]

#                 cls_token = model.cls_token + model.pos_embed[:, :1, :]
#                 cls_tokens = cls_token.expand(x.shape[0], -1, -1)
#                 x = torch.cat((cls_tokens, x), dim=1)

#                 # apply Transformer blocks
#                 for blk in model.blocks:
#                     x = blk(x)  # (bsz, L, 768)

#                 x = linear(x)  # (bsz, L, 512)
#                 logits = rearrange(x[:, 1:, :], 'b (h w) (c i j) -> b c (h i) (w j)', h=14, w=14, c=2, i=16, j=16)  # (bsz, 2, H, W)
#                 # logits = seg_head(features=x[:, 1:, :], HW_input=224, HW_target=224)  # (bsz, 2, H, W)

# #             print(logits.shape, plexus.shape)
#             if use_mixup:
#                 loss_original = loss_function(logits, plexus)
#                 loss_flipped = loss_function(logits, plexus.flip(0))
#                 loss = (1 - lam) * loss_flipped + lam * loss_original
#             else:
#                 loss = loss_function(logits, plexus)

#             loss.backward()
#             opt.step()
#             opt.zero_grad()
#             train_losses.append(loss.item())

#         val_losses = []
# Probability cut-off applied to class 1 when turning the softmax into hard labels.
thresh = 0.5


def _collect_predictions(backbone, head, loader, threshold):
    """Run the ViT encoder plus linear head over `loader` and gather hard predictions.

    The same loop used to be written out twice, once for validation and once for test.

    Args:
        backbone: model exposing `patch_embed`, `pos_embed`, `cls_token` and `blocks`.
        head: module mapping each token to 512 values (2 classes x 16 x 16 pixels).
        loader: iterable yielding `(img, plexus)` batches.
        threshold: class-1 probability above which a pixel is predicted as class 1.

    Returns:
        Tuple `(predictions, ground_truth)` of numpy arrays, concatenated over all
        batches along the first axis.
    """
    predicted_batches = []
    target_batches = []
    backbone.eval()
    for img, plexus in loader:
        img = img.cuda().to(dtype=torch.bfloat16) / 255  # (bsz, 3, H, W)
        plexus = plexus.cuda().long().squeeze(dim=1)  # (bsz, H, W)

        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
            # Patch tokens plus their positional embeddings (index 0 belongs to the cls token).
            x = backbone.patch_embed(img)
            x = x + backbone.pos_embed[:, 1:, :]

            # Prepend the cls token, with its own positional embedding, to every sample.
            cls_token = backbone.cls_token + backbone.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

            # apply Transformer blocks
            for blk in backbone.blocks:
                x = blk(x)  # (bsz, L, 768)

            x = head(x)  # (bsz, L, 512)
            # Drop the cls token and unfold each of the 14x14 tokens into a 2-class 16x16 tile.
            logits = rearrange(x[:, 1:, :], 'b (h w) (c i j) -> b c (h i) (w j)',
                               h=14, w=14, c=2, i=16, j=16)  # (bsz, 2, H, W)
            probability = logits.softmax(dim=1)
            # Threshold the class-1 probability (the alternative would be logits.argmax(dim=1)).
            predictions = (probability[:, 1, :, :] > threshold).long()

        predicted_batches.append(predictions.cpu())
        target_batches.append(plexus.cpu())

    return (torch.cat(predicted_batches, dim=0).numpy(),
            torch.cat(target_batches, dim=0).numpy())


# Validation split
all_predictions_val, all_gt_val = _collect_predictions(model, linear, val_loader, thresh)
val_miou = compute_iou(all_predictions_val, all_gt_val)

# Test split
all_predictions_test, all_gt_test = _collect_predictions(model, linear, test_loader, thresh)
test_miou = compute_iou(all_predictions_test, all_gt_test)

#         test_miou = mean_iou(results=all_predictions_test,
#                     gt_seg_maps=all_gt_test,
#                     num_classes=2,
#                     ignore_index=-1)

#         val_losses = torch.Tensor(val_losses).mean().item()
#         test_losses = torch.Tensor(test_losses).mean().item()
# The epoch loop that used to bind `epoch` is commented out in this cell, so bind it
# explicitly. Without this the summary line raises NameError on a fresh kernel.
epoch = 0
print(f'Epoch: {epoch}, Val mIoU: {val_miou}, Test mIoU: {test_miou}, Base LR: {base_lr}')

#         avg_val_miou = val_miou['IoU'].mean()

#         if val_miou > best_val_miou:
#             best_val_miou = val_miou
#             best_lr = base_lr
# #             best_model_state = copy.deepcopy(model.state_dict()) 
# #             best_linear_layer = copy.deepcopy(linear.state_dict()) 
#             print(f'Best Learning Rate: {best_lr}')
#             print(f'SAVING')
#             torch.save(obj={'backbone': model.state_dict(),
#                             'linear': linear.state_dict()},
#                        f=f'actual_plexus_saved_models/ViT_IN1k_plexus_{base_lr}.pt')

#             d = {'Learning Rate': base_lr, 'Epoch': epoch, 'Train Loss': train_losses, 'Val mIoU': val_miou, 
#                  'Test mIoU': test_miou}
#             model_info_df = pd.concat([model_info_df, pd.DataFrame([d])], ignore_index=True)

#             num_down = 0
#         else:
#             num_down += 1

#         # write to logs
#         with open(f'ViT_IN1k_plexus_logs_{base_lr}.csv', 'a', errors="ignore") as out_file:
#             csv_writer = csv.writer(out_file, delimiter=',', lineterminator='\n')
#             csv_writer.writerow([epoch, train_losses, val_miou, test_miou, best_val_miou, current_lr_backbone, current_lr_head, base_lr])


Learning Rate: 0.0001
Epoch: 0, Val mIoU: 0.25194102614667485, Test mIoU: 0.2518424799796887, Base LR: 0.0001
