In [9]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms, utils
from einops import rearrange
import os
from torchvision.transforms import (
    RandomHorizontalFlip,
    RandomRotation,
    RandomVerticalFlip,
    RandomApply,
    InterpolationMode,
    RandomCrop,
    RandomResizedCrop
)
import math
import csv
#from histo_vit import vit_small
import random
from torchvision.transforms.functional import hflip
from torchvision.transforms.functional import vflip
#import segmenter
import og_mae
from youssef_data_loading import HirschImagesDataset

In [10]:
def augment_image_with_map(_img, _map):
    side_outer = 512
    angle = torch.randint(low=0, high=90, size=(1,)).item()
    
    aug1 = torch.nn.Sequential(RandomRotation((angle, angle)))
    
    side_inner = side_outer / (math.cos(math.radians(angle)) + math.sin(math.radians(angle)))
    print(f"The new h and w are: {side_inner}")
    
    state = torch.get_rng_state()
    _img = aug1(_img)

    torch.set_rng_state(state)
    _map = aug1(_map)
    
    center_x = side_outer // 2
    center_y = side_outer // 2

    half_width = side_inner // 2
    half_height = side_inner // 2 

    start_x = round(center_x - half_width)
    end_x = round(center_x + half_width)
    start_y = round(center_y - half_height)
    end_y = round(center_y + half_height)

    _img = _img[:, start_y:end_y, start_x:end_x]
    _map = _map[:, start_y:end_y, start_x:end_x]
    
    aug2 = torch.nn.Sequential(
    RandomHorizontalFlip(p=0.5),
    RandomVerticalFlip(p=0.5),
    RandomResizedCrop(size=(224, 224), scale=(0.5, 2.0)))
    
    state = torch.get_rng_state()
    _img = aug2(_img)

    torch.set_rng_state(state)
    _map = aug2(_map)
    
    
    return _img, _map


In [11]:
def adjust_learning_rate(epoch, sched_config):
    """Decay the learning rate with half-cycle cosine after warmup"""
    if epoch < sched_config['warmup_epochs']:
        lr = sched_config['lr'] * epoch / sched_config['warmup_epochs']
    else:
        lr = sched_config['min_lr'] + (sched_config['lr'] - sched_config['min_lr']) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - sched_config['warmup_epochs']) / (sched_config['epochs'] - sched_config['warmup_epochs'])))
    return lr


In [12]:
def get_lr(optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']

In [13]:
file_names = ['S14-580.pt',
              'S00-1910.pt',
              'S02-410.pt',
              'S02-484.pt',
              'S03-2391.pt',
              'S01-18.pt',
              "S03-3178 D2.pt",
              "S03-3178 D3.pt",
              "S03-3178 D4.pt",
              'S04-52.pt',
              'S04-910.pt',
              'S07-1808.pt',
              'S08-2215.pt',
              'S09-2723.pt',
              'S04-1840.pt',
              'S07-1465.pt',
              'S14-1715.pt',
              'S09-2909.pt',
              'S14-3414.pt',
              'S14-2038.pt',
              'S15-1442.pt',
              'S15-1518.pt',
              'S16-567.pt',
              "S16-1197 B1.pt",
              'S11-1760.pt',
              'S16-1467.pt',
              "S16-1197 B3.pt",
              "S16-1197 B2.pt",
              'S97-2054.pt',
              'S16-1415.pt']

In [15]:
for fold_num in range(5):

#     validation_files = file_names[(6*fold_num):(6*fold_num+6)]

#     train_imgs = []
#     train_labels = []
#     val_imgs = []
#     val_labels = []

#     data_paths = os.listdir('muscle_5x_normed')
#     for i_path, data_path in enumerate(data_paths):
#         torch_obj = torch.load(f'muscle_5x_normed/{data_path}')

#         if data_path in validation_files:
#             val_imgs.append(torch_obj['imgs'])
#             val_labels.append(torch_obj['muscles'])
#         else:
#             train_imgs.append(torch_obj['imgs'])
#             train_labels.append(torch_obj['muscles'])

#     # 512 now not 256
#     train_imgs = torch.cat(train_imgs, dim=0)  # (48_000, 3, 256, 256) 
#     train_labels = torch.cat(train_labels, dim=0)  # (48_000, 1, 256, 256)
#     val_imgs = torch.cat(val_imgs, dim=0)  # (12_000, 3, 256, 256)
#     val_labels = torch.cat(val_labels, dim=0)  # (12_000, 1, 256, 256)

    print(f'Starting fold: {fold_num}')
#     print(train_imgs.shape, train_labels.shape)
#     print(val_imgs.shape, val_labels.shape)

    batch_size = 100
    
    train_dataset = HirschImagesDataset(data_file_path="muscle_train", do_augmentation=True)
    train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=8
                             )
    
    val_dataset = HirschImagesDataset(data_file_path="muscle_val", do_augmentation=False)
    val_loader = DataLoader(val_dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=8
                           )
    
    test_dataset = HirschImagesDataset(data_file_path="muscle_test", do_augmentation=False)
    test_loader = DataLoader(test_dataset,
                          batch_size=batch_size,
                          shuffle=False,
                          num_workers=8
                            )
    
#     train_loader = DataLoader(TensorDataset(train_imgs, train_labels), batch_size=batch_size, shuffle=True)
#     val_loader = DataLoader(TensorDataset(val_imgs, val_labels), batch_size=batch_size, shuffle=False)

#     del train_imgs
#     del train_labels
#     del val_imgs
#     del val_labels

    best_val_loss = 10
    base_lr = 1e-4
    learning_rate = base_lr * batch_size / 256 # 256 before

    model = og_mae.mae_vit_base_patch16_dec512d8b().cuda()
    model.load_state_dict(torch.load('mae_visualize_vit_base.pth')['model'])
    linear = nn.Linear(768, 512).cuda()

    # decoder = segmenter.MaskTransformer(n_cls=2,
    #                                     patch_size=16,
    #                                     d_encoder=384,
    #                                     n_layers=2,
    #                                     n_heads=12,
    #                                     d_model=384,
    #                                     d_ff=1536,
    #                                     drop_path_rate=0,
    #                                     dropout=0)
    # seg_head = segmenter.Segmenter(decoder=decoder, n_cls=2).cuda()

    # optimizer
    backbone_params = model.parameters()
    linear_params = linear.parameters()
    # head_params = seg_head.parameters()
    opt = torch.optim.AdamW([{'params': backbone_params}, {'params': linear_params}], lr=learning_rate)
    loss_function = torch.nn.CrossEntropyLoss()

    # Prep LR stepping
    epochs = 50
    multiplier = 1
    backbone_config = {'lr': learning_rate,
                       'warmup_epochs': 5,
                       'min_lr': 0,
                       'epochs': epochs}

    head_config = {'lr': multiplier * learning_rate,
                   'warmup_epochs': 5,
                   'min_lr': 0,
                   'epochs': epochs}
    num_down = 0
    for epoch in range(epochs):
        if num_down >= 20:
            break

        opt.param_groups[0]['lr'] = adjust_learning_rate(epoch, backbone_config)
        opt.param_groups[1]['lr'] = adjust_learning_rate(epoch, head_config)

        current_lr_backbone = opt.param_groups[0]['lr']  # confirm
        current_lr_head = opt.param_groups[1]['lr']  # confirm

        train_losses = []

        model = model.train()
        # seg_head = seg_head.train()
        linear = linear.train()
        for batch in train_loader:
            img, plexus = batch  # load from batch
            
            # Q: I shouldn't augment again right?
            #img, plexus = augment_image_with_map(img.cuda(), plexus.cuda())  # perform data augmentation

            img = img.cuda().to(dtype=torch.bfloat16) / 255  # (bsz, 3, H, W)
            plexus = plexus.cuda().long()  # (bsz, H, W)

            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                x = model.patch_embed(img)
                x = x + model.pos_embed[:, 1:, :]

                cls_token = model.cls_token + model.pos_embed[:, :1, :]
                cls_tokens = cls_token.expand(x.shape[0], -1, -1)
                x = torch.cat((cls_tokens, x), dim=1)

                # apply Transformer blocks
                for blk in model.blocks:
                    x = blk(x)  # (bsz, L, 768)

                x = linear(x)  # (bsz, L, 512)
                logits = rearrange(x[:, 1:, :], 'b (h w) (c i j) -> b c (h i) (w j)', h=14, w=14, c=2, i=16, j=16)  # (bsz, 2, H, W)
                # logits = seg_head(features=x[:, 1:, :], HW_input=224, HW_target=224)  # (bsz, 2, H, W)


            loss = loss_function(logits, plexus)
            loss.backward()
            opt.step()
            opt.zero_grad()
            train_losses.append(loss.item())

        val_losses = []
        model.eval()
        for batch in val_loader:
            img, plexus = batch  # load from batch
            img = img.cuda().to(dtype=torch.bfloat16) / 255  # (bsz, 3, H, W)
            plexus = plexus.cuda().long()  # (bsz, H, W)

            with torch.no_grad():
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    x = model.patch_embed(img)
                    x = x + model.pos_embed[:, 1:, :]

                    cls_token = model.cls_token + model.pos_embed[:, :1, :]
                    cls_tokens = cls_token.expand(x.shape[0], -1, -1)
                    x = torch.cat((cls_tokens, x), dim=1)

                    # apply Transformer blocks
                    for blk in model.blocks:
                        x = blk(x)  # (bsz, L, 768)

                    x = linear(x)  # (bsz, L, 512)
                    logits = rearrange(x[:, 1:, :], 'b (h w) (c i j) -> b c (h i) (w j)', h=14, w=14, c=2, i=16,
                                       j=16)  # (bsz, 2, H, W)
                    # logits = seg_head(features=x[:, 1:, :], HW_input=224, HW_target=224)  # (bsz, 2, H, W)

            loss = loss_function(logits, plexus)
            val_losses.append(loss.item())
            
        test_losses = []
        model.eval()
        for batch in test_loader:
            img, plexus = batch  # load from batch
            img = img.cuda().to(dtype=torch.bfloat16) / 255  # (bsz, 3, H, W)
            plexus = plexus.cuda().long()  # (bsz, H, W)

            with torch.no_grad():
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    x = model.patch_embed(img)
                    x = x + model.pos_embed[:, 1:, :]

                    cls_token = model.cls_token + model.pos_embed[:, :1, :]
                    cls_tokens = cls_token.expand(x.shape[0], -1, -1)
                    x = torch.cat((cls_tokens, x), dim=1)

                    # apply Transformer blocks
                    for blk in model.blocks:
                        x = blk(x)  # (bsz, L, 768)

                    x = linear(x)  # (bsz, L, 512)
                    logits = rearrange(x[:, 1:, :], 'b (h w) (c i j) -> b c (h i) (w j)', h=14, w=14, c=2, i=16,
                                       j=16)  # (bsz, 2, H, W)
                    # logits = seg_head(features=x[:, 1:, :], HW_input=224, HW_target=224)  # (bsz, 2, H, W)

            loss = loss_function(logits, plexus)
            test_losses.append(loss.item())

        train_losses = torch.Tensor(train_losses).mean().item()
        val_losses = torch.Tensor(val_losses).mean().item()
        test_losses = torch.Tensor(test_losses).mean().item()
        print(f'Epoch: {epoch}, Train Loss: {train_losses}, Val Loss: {val_losses}, Test Loss: {test_losses}, LR Backbone: {current_lr_backbone}, LR Head: {current_lr_head},')

        if best_val_loss > val_losses:
            best_val_loss = val_losses
            print(f'SAVING')
            # torch.save(obj={'backbone': model.state_dict(),
            #                 'head': seg_head.state_dict()},
            #            f=f'saved_models/ViT_HIPT_{fold_num}_muscle_5x_{base_lr}.pt')
            torch.save(obj={'backbone': model.state_dict(),
                            'linear': linear.state_dict()},
                       f=f'saved_models/ViT_IN1k_{fold_num}_muscle_5x_{base_lr}.pt')
            num_down = 0
        else:
            num_down += 1

        # write to logs
        with open(f'ViT_IN1k_muscle_logs_5x_{base_lr}.csv', 'a', errors="ignore") as out_file:
            csv_writer = csv.writer(out_file, delimiter=',', lineterminator='\n')
            csv_writer.writerow([epoch, train_losses, val_losses, test_losses, best_val_loss, current_lr_backbone, current_lr_head, base_lr, fold_num])
            

Starting fold: 0
The new h and w are: 479.7059444517019The new h and w are: 426.90818316598495The new h and w are: 503.2929803092664The new h and w are: 441.96761088688356The new h and w are: 410.0341570616378The new h and w are: 503.2929803092664The new h and w are: 373.12198968145515






The new h and w are: 512.0
The new h and w are: 382.89958686219796
The new h and w are: 380.66998728722893The new h and w are: 453.32153080204694

The new h and w are: 362.9227335058999
The new h and w are: 364.7575192033416The new h and w are: 364.03288218551364The new h and w are: 370.12683126341176The new h and w are: 364.75751920334164



The new h and w are: 422.36640580078557
The new h and w are: 367.62370204758565
The new h and w are: 396.3006645577891
The new h and w are: 441.96761088688356
The new h and w are: 387.79591098099945The new h and w are: 459.4336692874594

The new h and w are: 422.36640580078557
The new h and w are: 385.27350748452477
The new h and w are: 453.32153080204694
The 

The new h and w are: 362.9227335058999
The new h and w are: 365.5966332708172
The new h and w are: 365.5966332708172
The new h and w are: 387.79591098099945
The new h and w are: 426.90818316598495
The new h and w are: 399.465476459458
The new h and w are: 472.60792140656827The new h and w are: 374.8100134752651

The new h and w are: 367.62370204758565
The new h and w are: 465.85672437745495
The new h and w are: 418.0462494349958
The new h and w are: 465.85672437745495The new h and w are: 363.4216008788375

The new h and w are: 376.6286262078867
The new h and w are: 362.2593505759632
The new h and w are: 387.79591098099945
The new h and w are: 459.4336692874594
The new h and w are: 367.62370204758565
The new h and w are: 368.8148393527191
The new h and w are: 436.69754417973104
The new h and w are: 362.03867196751236
The new h and w are: 512.0
The new h and w are: 459.4336692874594
The new h and w are: 399.465476459458The new h and w are: 367.62370204758565

The new h and w are: 362.922

The new h and w are: 380.66998728722893The new h and w are: 362.9227335058999
The new h and w are: 422.36640580078557

The new h and w are: 413.93838832150976The new h and w are: 371.5617762203221

The new h and w are: 479.705944451702
The new h and w are: 487.17108638271094
The new h and w are: 382.89958686219796
The new h and w are: 406.3255005874387
The new h and w are: 373.12198968145515
The new h and w are: 367.62370204758565The new h and w are: 362.53551429007956
The new h and w are: 406.3255005874386

The new h and w are: 364.75751920334164
The new h and w are: 367.62370204758565
The new h and w are: 362.53551429007956
The new h and w are: 371.5617762203221The new h and w are: 387.79591098099945

The new h and w are: 376.6286262078867
The new h and w are: 362.0938206102321
The new h and w are: 373.12198968145515
The new h and w are: 368.8148393527191
The new h and w are: 370.12683126341176
The new h and w are: 465.85672437745495
The new h and w are: 465.85672437745495
The new h 

The new h and w are: 393.30446831393897The new h and w are: 447.50440903559274

The new h and w are: 371.5617762203221
The new h and w are: 365.5966332708172
The new h and w are: 362.2593505759632
The new h and w are: 366.5515295537089The new h and w are: 363.4216008788375

The new h and w are: 366.5515295537089
The new h and w are: 503.2929803092664
The new h and w are: 487.17108638271094
The new h and w are: 465.85672437745495
The new h and w are: 472.60792140656827
The new h and w are: 362.03867196751236
The new h and w are: 503.2929803092664
The new h and w are: 363.42160087883747
The new h and w are: 382.89958686219796
The new h and w are: 387.79591098099945
The new h and w are: 374.8100134752651The new h and w are: 465.85672437745495

The new h and w are: 364.03288218551364
The new h and w are: 418.0462494349958
The new h and w are: 371.5617762203221
The new h and w are: 368.8148393527191
The new h and w are: 363.42160087883747
The new h and w are: 368.8148393527191
The new h and

The new h and w are: 418.0462494349958
The new h and w are: 431.6816230411364The new h and w are: 390.471285909815The new h and w are: 495.0254181608458


The new h and w are: 406.3255005874386
The new h and w are: 368.8148393527191The new h and w are: 367.62370204758565
The new h and w are: 479.7059444517019

The new h and w are: 376.6286262078867The new h and w are: 380.66998728722893

The new h and w are: 367.62370204758565
The new h and w are: 367.62370204758565
The new h and w are: 362.2593505759632The new h and w are: 447.50440903559274The new h and w are: 441.96761088688356


The new h and w are: 368.8148393527191
The new h and w are: 385.27350748452477
The new h and w are: 399.465476459458
The new h and w are: 378.58085480598237
The new h and w are: 399.465476459458
The new h and w are: 453.32153080204694
The new h and w are: 396.3006645577891
The new h and w are: 393.30446831393897
The new h and w are: 422.36640580078557The new h and w are: 362.2593505759632
The new h and w ar

The new h and w are: 382.89958686219796
The new h and w are: 387.79591098099945
The new h and w are: 382.89958686219796The new h and w are: 402.8049289581146

The new h and w are: 459.4336692874594
The new h and w are: 374.8100134752652
The new h and w are: 413.93838832150976
The new h and w are: 362.03867196751236
The new h and w are: 472.60792140656827
The new h and w are: 370.12683126341176
The new h and w are: 479.7059444517019
The new h and w are: 367.62370204758565
The new h and w are: 374.8100134752651
The new h and w are: 364.0328821855136
The new h and w are: 362.9227335058999
The new h and w are: 487.17108638271105
The new h and w are: 371.5617762203221
The new h and w are: 382.89958686219796The new h and w are: 362.2593505759632

The new h and w are: 503.2929803092663
The new h and w are: 382.89958686219796
The new h and w are: 362.53551429007956
The new h and w are: 365.5966332708172
The new h and w are: 422.36640580078557
The new h and w are: 367.62370204758565
The new h a

The new h and w are: 376.6286262078867The new h and w are: 436.69754417973104

The new h and w are: 364.7575192033416
The new h and w are: 363.4216008788375
The new h and w are: 380.66998728722893
The new h and w are: 422.36640580078557
The new h and w are: 402.8049289581146The new h and w are: 406.3255005874387

The new h and w are: 422.36640580078557
The new h and w are: 431.6816230411364
The new h and w are: 373.12198968145515
The new h and w are: 387.79591098099945The new h and w are: 503.2929803092663

The new h and w are: 363.4216008788375The new h and w are: 385.27350748452477

The new h and w are: 410.0341570616378
The new h and w are: 393.30446831393897
The new h and w are: 447.50440903559274The new h and w are: 487.17108638271094The new h and w are: 371.5617762203221


The new h and w are: 362.2593505759632
The new h and w are: 396.3006645577891
The new h and w are: 447.50440903559274
The new h and w are: 385.27350748452477The new h and w are: 503.2929803092663

The new h and

The new h and w are: 495.0254181608458
The new h and w are: 472.60792140656827
The new h and w are: 459.4336692874594The new h and w are: 426.90818316598495

The new h and w are: 393.30446831393897The new h and w are: 364.7575192033416

The new h and w are: 368.8148393527191
The new h and w are: 472.60792140656827
The new h and w are: 453.32153080204694
The new h and w are: 396.3006645577891
The new h and w are: 374.8100134752652
The new h and w are: 367.62370204758565
The new h and w are: 364.75751920334164
The new h and w are: 453.32153080204694
The new h and w are: 453.32153080204694The new h and w are: 402.8049289581146

The new h and w are: 422.36640580078557
The new h and w are: 447.50440903559274
The new h and w are: 362.2593505759632
The new h and w are: 387.79591098099945
The new h and w are: 374.8100134752652
The new h and w are: 406.3255005874386
The new h and w are: 503.2929803092663
The new h and w are: 503.2929803092664
The new h and w are: 402.8049289581146The new h and 

RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: : [100, 1, 224, 224]

In [16]:
torch.cuda.is_available()

True

In [17]:
print("Torch version:",torch.__version__)

print("Is CUDA enabled?",torch.cuda.is_available())

Torch version: 2.1.0+cu118
Is CUDA enabled? True
