In [2]:
'''%load_ext autoreload
%autoreload 2
'''
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as tvtransforms
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader

import fastmri
from fastmri.data import subsample
from fastmri.data import transforms, mri_data
from fastmri.losses import SSIMLoss

from models.mymodels import FastMRICVT, FastMRIEncoderDecoder
import json
import pytorch_lightning



In [3]:
# Per-stage hyper-parameters for the FastMRICVT encoder. Every list-valued entry
# holds one value per stage, so its length must equal NUM_STAGES.
spec = {
    'INIT': 'trunc_norm',
    'NUM_STAGES': 3,
    'PATCH_SIZE': [5, 3, 3],
    'PATCH_STRIDE': [2, 2, 2],
    'PATCH_PADDING': [1, 1, 1],
    # Each stage's embedding width must split evenly across that stage's
    # attention heads (192 / 8 = 24; 196 would not split across 8 heads).
    'DIM_EMBED': [32, 128, 192],
    'NUM_HEADS': [1, 4, 8],
    'DEPTH': [1, 2, 8],
    'MLP_RATIO': [4.0, 4.0, 4.0],
    'ATTN_DROP_RATE': [0.0, 0.0, 0.0],
    'DROP_RATE': [0.0, 0.0, 0.0],
    'DROP_PATH_RATE': [0.0, 0.0, 0.0],
    'QKV_BIAS': [True, True, True],
    'CLS_TOKEN': [False, False, False],
    'POS_EMBED': [False, False, False],
    'QKV_PROJ_METHOD': ['dw_bn', 'dw_bn', 'dw_bn'],
    'KERNEL_QKV': [3, 3, 3],
    'PADDING_KV': [1, 1, 1],
    'STRIDE_KV': [2, 2, 2],
    'PADDING_Q': [1, 1, 1],
    'STRIDE_Q': [1, 1, 1],
}

# Fail fast on per-stage typos (e.g. a stray `0,0` adding a fourth entry)
# before the model is built from this spec.
assert all(
    len(value) == spec['NUM_STAGES'] for value in spec.values() if isinstance(value, list)
), "every per-stage list in `spec` must have NUM_STAGES entries"
assert all(
    dim % heads == 0 for dim, heads in zip(spec['DIM_EMBED'], spec['NUM_HEADS'])
), "DIM_EMBED[i] must be divisible by NUM_HEADS[i]"

In [3]:
 mask_func = subsample.RandomMaskFunc(
        center_fractions=[0.08, 0.04],
        accelerations=[4, 8]    
        )   
train = mri_data.SliceDataset(
    root='../singlecoil_val',
    transform=transforms.UnetDataTransform('singlecoil', mask_func=mask_func),
    challenge='singlecoil'
)

In [4]:
# Training hyper-parameters.
epochs = 100
learning_rate = 1e-4
weight_decay = 1e-16  # NOTE(review): effectively zero -- confirm this is intended
batch_size = 4

# Use the GPU when one is present, otherwise fall back to the CPU. The training
# and validation loops move every batch to this same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = FastMRICVT(spec=spec).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Cosine decay with T_max = epochs; train_loop calls scheduler.step() once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
log_every = 10

# SSIM is the training objective; MSE (`l2`) is what val_loop reports.
criterion = SSIMLoss()
# Move the SSIM window tensor to the compute device explicitly.
criterion.w = criterion.w.to(device)
l2 = torch.nn.MSELoss()

def train_loop(epoch, model, loader):
    """Run one training epoch with the SSIM loss and advance the LR scheduler.

    Relies on the notebook-level ``optimizer``, ``scheduler``, ``criterion``,
    ``device`` and ``log_every``.

    Args:
        epoch: epoch index, used only in the log line.
        model: network being trained; it receives inputs with a channel axis
            added, the same way ``val_loop`` feeds it.
        loader: DataLoader yielding 7-element batches, unpacked the same way as
            in ``val_loop``: ``(inputs, targets, _, _, _, _, max_val)``.

    Returns:
        Mean per-sample training loss for the epoch, or ``nan`` if the loader
        yielded no batches.
    """
    model.train()
    results = {'loss': 0, 'counter': 0, 'loss_arr': []}

    for i, (inputs, targets, _, _, _, _, max_val) in enumerate(loader):
        # Everything goes to `device` (where the model lives) and gains the
        # channel axis, matching val_loop.
        inputs = inputs.to(device)[:, None, :, :]
        targets = targets.to(device)[:, None, :, :]

        optimizer.zero_grad()
        pred = model(inputs)
        # as_tensor accepts either a collated tensor or a list; match pred's
        # dtype/device so the SSIM arithmetic does not mix precisions or devices.
        data_range = torch.as_tensor(max_val, dtype=pred.dtype, device=device)
        loss = criterion(pred, targets, data_range)
        loss.backward()
        optimizer.step()

        results['loss'] += loss.item() * len(inputs)
        results['counter'] += len(inputs)
        results['loss_arr'].append(loss.item())
        if i % log_every == 0:
            recent = results['loss_arr'][-10:]
            print("Train: Epoch: %d \t Iteration: %d \t loss: %.4f" % (epoch, i, sum(recent) / len(recent)))

    scheduler.step()
    if results['counter'] == 0:
        return float('nan')
    return results['loss'] / results['counter']


In [5]:
def val_loop(epoch, model, loader):
    """Evaluate ``model`` on ``loader`` with the notebook-level MSE loss ``l2``.

    Relies on the notebook-level ``device``, ``l2`` and ``log_every``.

    Args:
        epoch: epoch index, used only in the log line.
        model: network to evaluate; it is put in eval mode and run without
            gradient tracking.
        loader: DataLoader yielding 7-element batches
            ``(inputs, targets, _, _, _, _, max_val)``.

    Returns:
        Mean per-sample MSE over every batch in ``loader``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    results = {'loss': 0, 'counter': 0, 'loss_arr': []}
    with torch.no_grad():
        for i, (inputs, targets, _, _, _, _, _max_val) in enumerate(loader):
            # Add a channel axis before the forward pass.
            inputs = inputs.to(device)[:, None, :, :]
            targets = targets.to(device)[:, None, :, :]

            pred = model(inputs)
            # NOTE(review): training optimises SSIM (`criterion`) while this
            # reports MSE, so train and val curves are not directly comparable.
            loss = l2(pred, targets)
            results['loss'] += loss.item() * len(inputs)
            results['counter'] += len(inputs)
            results['loss_arr'].append(loss.item())

            if i % log_every == 0:
                recent = results['loss_arr'][-10:]
                print("Val: Epoch %d \t Iteration %d \t loss %.4f" % (epoch, i, sum(recent) / len(recent)))

    if results['counter'] == 0:
        raise ValueError("val_loop: loader yielded no batches")
    return results['loss'] / results['counter']


In [6]:
def main(train_model=False, results_path=None):
    """Run the epoch loop: optionally train, then validate and track the best epoch.

    Relies on the notebook-level ``epochs``, ``model``, ``train_loader`` and
    ``val_loader``.

    Args:
        train_model: when True, run ``train_loop`` before validating each epoch.
            Defaults to False, matching the notebook's validation-only runs.
        results_path: optional file path; when given, the results dict is
            written there as JSON after every epoch, so an interrupted run keeps
            its history.

    Returns:
        The results dict with keys ``epochs``, ``losess`` (per-epoch validation
        loss; key spelling kept so existing readers of the JSON still work),
        ``best_val`` and ``best_epoch``.
    """
    results = {'epochs': [], 'losess': [], 'best_val': 1e10, 'best_epoch': 0}

    for epoch in range(epochs):
        if train_model:
            train_loop(epoch, model, train_loader)

        # Validate the same `model` that train_loop updates.
        val_loss = val_loop(epoch, model, val_loader)

        results['epochs'].append(epoch)
        results['losess'].append(val_loss)

        if val_loss < results['best_val']:
            results['best_val'] = val_loss
            results['best_epoch'] = epoch

        # Report every epoch, not only on improvement (inside the `if` the
        # current and best values would always be identical).
        print("Val loss: %.4f  \t epoch %d" % (val_loss, epoch))
        print("Best: val loss: %.4f \t epoch %d" % (results['best_val'], results['best_epoch']))

        if results_path is not None:
            with open(results_path, "w") as outfile:
                json.dump(results, outfile, indent=4)

    return results

In [7]:
# Only training benefits from shuffling. Validation visits batches in a fixed
# order, so its per-iteration log lines are reproducible run to run.
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False)

# Keep whatever main() returns for later inspection instead of discarding it.
training_results = main()

RuntimeError: Error(s) in loading state_dict for FastMRICVT:
	Missing key(s) in state_dict: "head.1.weight", "head.1.bias", "head.3.weight", "head.3.bias". 
	Unexpected key(s) in state_dict: "stage0.blocks.1.norm1.weight", "stage0.blocks.1.norm1.bias", "stage0.blocks.1.attn.conv_proj_q.conv.weight", "stage0.blocks.1.attn.conv_proj_q.bn.weight", "stage0.blocks.1.attn.conv_proj_q.bn.bias", "stage0.blocks.1.attn.conv_proj_q.bn.running_mean", "stage0.blocks.1.attn.conv_proj_q.bn.running_var", "stage0.blocks.1.attn.conv_proj_q.bn.num_batches_tracked", "stage0.blocks.1.attn.conv_proj_k.conv.weight", "stage0.blocks.1.attn.conv_proj_k.bn.weight", "stage0.blocks.1.attn.conv_proj_k.bn.bias", "stage0.blocks.1.attn.conv_proj_k.bn.running_mean", "stage0.blocks.1.attn.conv_proj_k.bn.running_var", "stage0.blocks.1.attn.conv_proj_k.bn.num_batches_tracked", "stage0.blocks.1.attn.conv_proj_v.conv.weight", "stage0.blocks.1.attn.conv_proj_v.bn.weight", "stage0.blocks.1.attn.conv_proj_v.bn.bias", "stage0.blocks.1.attn.conv_proj_v.bn.running_mean", "stage0.blocks.1.attn.conv_proj_v.bn.running_var", "stage0.blocks.1.attn.conv_proj_v.bn.num_batches_tracked", "stage0.blocks.1.attn.proj_q.weight", "stage0.blocks.1.attn.proj_q.bias", "stage0.blocks.1.attn.proj_k.weight", "stage0.blocks.1.attn.proj_k.bias", "stage0.blocks.1.attn.proj_v.weight", "stage0.blocks.1.attn.proj_v.bias", "stage0.blocks.1.attn.proj.weight", "stage0.blocks.1.attn.proj.bias", "stage0.blocks.1.norm2.weight", "stage0.blocks.1.norm2.bias", "stage0.blocks.1.mlp.fc1.weight", "stage0.blocks.1.mlp.fc1.bias", "stage0.blocks.1.mlp.fc2.weight", "stage0.blocks.1.mlp.fc2.bias", "stage1.blocks.2.norm1.weight", "stage1.blocks.2.norm1.bias", "stage1.blocks.2.attn.conv_proj_q.conv.weight", "stage1.blocks.2.attn.conv_proj_q.bn.weight", "stage1.blocks.2.attn.conv_proj_q.bn.bias", "stage1.blocks.2.attn.conv_proj_q.bn.running_mean", "stage1.blocks.2.attn.conv_proj_q.bn.running_var", "stage1.blocks.2.attn.conv_proj_q.bn.num_batches_tracked", "stage1.blocks.2.attn.conv_proj_k.conv.weight", "stage1.blocks.2.attn.conv_proj_k.bn.weight", "stage1.blocks.2.attn.conv_proj_k.bn.bias", 
"stage1.blocks.2.attn.conv_proj_k.bn.running_mean", "stage1.blocks.2.attn.conv_proj_k.bn.running_var", "stage1.blocks.2.attn.conv_proj_k.bn.num_batches_tracked", "stage1.blocks.2.attn.conv_proj_v.conv.weight", "stage1.blocks.2.attn.conv_proj_v.bn.weight", "stage1.blocks.2.attn.conv_proj_v.bn.bias", "stage1.blocks.2.attn.conv_proj_v.bn.running_mean", "stage1.blocks.2.attn.conv_proj_v.bn.running_var", "stage1.blocks.2.attn.conv_proj_v.bn.num_batches_tracked", "stage1.blocks.2.attn.proj_q.weight", "stage1.blocks.2.attn.proj_q.bias", "stage1.blocks.2.attn.proj_k.weight", "stage1.blocks.2.attn.proj_k.bias", "stage1.blocks.2.attn.proj_v.weight", "stage1.blocks.2.attn.proj_v.bias", "stage1.blocks.2.attn.proj.weight", "stage1.blocks.2.attn.proj.bias", "stage1.blocks.2.norm2.weight", "stage1.blocks.2.norm2.bias", "stage1.blocks.2.mlp.fc1.weight", "stage1.blocks.2.mlp.fc1.bias", "stage1.blocks.2.mlp.fc2.weight", "stage1.blocks.2.mlp.fc2.bias", "stage1.blocks.3.norm1.weight", "stage1.blocks.3.norm1.bias", "stage1.blocks.3.attn.conv_proj_q.conv.weight", "stage1.blocks.3.attn.conv_proj_q.bn.weight", "stage1.blocks.3.attn.conv_proj_q.bn.bias", "stage1.blocks.3.attn.conv_proj_q.bn.running_mean", "stage1.blocks.3.attn.conv_proj_q.bn.running_var", "stage1.blocks.3.attn.conv_proj_q.bn.num_batches_tracked", "stage1.blocks.3.attn.conv_proj_k.conv.weight", "stage1.blocks.3.attn.conv_proj_k.bn.weight", "stage1.blocks.3.attn.conv_proj_k.bn.bias", "stage1.blocks.3.attn.conv_proj_k.bn.running_mean", "stage1.blocks.3.attn.conv_proj_k.bn.running_var", "stage1.blocks.3.attn.conv_proj_k.bn.num_batches_tracked", "stage1.blocks.3.attn.conv_proj_v.conv.weight", "stage1.blocks.3.attn.conv_proj_v.bn.weight", "stage1.blocks.3.attn.conv_proj_v.bn.bias", "stage1.blocks.3.attn.conv_proj_v.bn.running_mean", "stage1.blocks.3.attn.conv_proj_v.bn.running_var", "stage1.blocks.3.attn.conv_proj_v.bn.num_batches_tracked", "stage1.blocks.3.attn.proj_q.weight", "stage1.blocks.3.attn.proj_q.bias", 
"stage1.blocks.3.attn.proj_k.weight", "stage1.blocks.3.attn.proj_k.bias", "stage1.blocks.3.attn.proj_v.weight", "stage1.blocks.3.attn.proj_v.bias", "stage1.blocks.3.attn.proj.weight", "stage1.blocks.3.attn.proj.bias", "stage1.blocks.3.norm2.weight", "stage1.blocks.3.norm2.bias", "stage1.blocks.3.mlp.fc1.weight", "stage1.blocks.3.mlp.fc1.bias", "stage1.blocks.3.mlp.fc2.weight", "stage1.blocks.3.mlp.fc2.bias", "stage1.blocks.4.norm1.weight", "stage1.blocks.4.norm1.bias", "stage1.blocks.4.attn.conv_proj_q.conv.weight", "stage1.blocks.4.attn.conv_proj_q.bn.weight", "stage1.blocks.4.attn.conv_proj_q.bn.bias", "stage1.blocks.4.attn.conv_proj_q.bn.running_mean", "stage1.blocks.4.attn.conv_proj_q.bn.running_var", "stage1.blocks.4.attn.conv_proj_q.bn.num_batches_tracked", "stage1.blocks.4.attn.conv_proj_k.conv.weight", "stage1.blocks.4.attn.conv_proj_k.bn.weight", "stage1.blocks.4.attn.conv_proj_k.bn.bias", "stage1.blocks.4.attn.conv_proj_k.bn.running_mean", "stage1.blocks.4.attn.conv_proj_k.bn.running_var", "stage1.blocks.4.attn.conv_proj_k.bn.num_batches_tracked", "stage1.blocks.4.attn.conv_proj_v.conv.weight", "stage1.blocks.4.attn.conv_proj_v.bn.weight", "stage1.blocks.4.attn.conv_proj_v.bn.bias", "stage1.blocks.4.attn.conv_proj_v.bn.running_mean", "stage1.blocks.4.attn.conv_proj_v.bn.running_var", "stage1.blocks.4.attn.conv_proj_v.bn.num_batches_tracked", "stage1.blocks.4.attn.proj_q.weight", "stage1.blocks.4.attn.proj_q.bias", "stage1.blocks.4.attn.proj_k.weight", "stage1.blocks.4.attn.proj_k.bias", "stage1.blocks.4.attn.proj_v.weight", "stage1.blocks.4.attn.proj_v.bias", "stage1.blocks.4.attn.proj.weight", "stage1.blocks.4.attn.proj.bias", "stage1.blocks.4.norm2.weight", "stage1.blocks.4.norm2.bias", "stage1.blocks.4.mlp.fc1.weight", "stage1.blocks.4.mlp.fc1.bias", "stage1.blocks.4.mlp.fc2.weight", "stage1.blocks.4.mlp.fc2.bias", "stage1.blocks.5.norm1.weight", "stage1.blocks.5.norm1.bias", "stage1.blocks.5.attn.conv_proj_q.conv.weight", 
"stage1.blocks.5.attn.conv_proj_q.bn.weight", "stage1.blocks.5.attn.conv_proj_q.bn.bias", "stage1.blocks.5.attn.conv_proj_q.bn.running_mean", "stage1.blocks.5.attn.conv_proj_q.bn.running_var", "stage1.blocks.5.attn.conv_proj_q.bn.num_batches_tracked", "stage1.blocks.5.attn.conv_proj_k.conv.weight", "stage1.blocks.5.attn.conv_proj_k.bn.weight", "stage1.blocks.5.attn.conv_proj_k.bn.bias", "stage1.blocks.5.attn.conv_proj_k.bn.running_mean", "stage1.blocks.5.attn.conv_proj_k.bn.running_var", "stage1.blocks.5.attn.conv_proj_k.bn.num_batches_tracked", "stage1.blocks.5.attn.conv_proj_v.conv.weight", "stage1.blocks.5.attn.conv_proj_v.bn.weight", "stage1.blocks.5.attn.conv_proj_v.bn.bias", "stage1.blocks.5.attn.conv_proj_v.bn.running_mean", "stage1.blocks.5.attn.conv_proj_v.bn.running_var", "stage1.blocks.5.attn.conv_proj_v.bn.num_batches_tracked", "stage1.blocks.5.attn.proj_q.weight", "stage1.blocks.5.attn.proj_q.bias", "stage1.blocks.5.attn.proj_k.weight", "stage1.blocks.5.attn.proj_k.bias", "stage1.blocks.5.attn.proj_v.weight", "stage1.blocks.5.attn.proj_v.bias", "stage1.blocks.5.attn.proj.weight", "stage1.blocks.5.attn.proj.bias", "stage1.blocks.5.norm2.weight", "stage1.blocks.5.norm2.bias", "stage1.blocks.5.mlp.fc1.weight", "stage1.blocks.5.mlp.fc1.bias", "stage1.blocks.5.mlp.fc2.weight", "stage1.blocks.5.mlp.fc2.bias", "head.4.weight", "head.4.bias". 
	size mismatch for head.0.weight: copying a param with shape torch.Size([192, 48, 3, 3]) from checkpoint, the shape in current model is torch.Size([192, 48, 2, 2]).

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 109461919.61it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
