In [1]:
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu116
Note: you may need to restart the kernel to use updated packages.


In [2]:
from models.DifViT import ViT as DifViT
from diffusion import GaussianDiffusion

SPT = Shifted Patch Tokenization (`is_SPT`)
LSA = Locality Self-Attention (`is_LSA`)
FFNT = whether the time embedding is added before the FFN layer (`ffn_time`)

In [3]:
# Small DifViT instance used only for the shape smoke test in the next cells;
# `model` is rebound to the full GENVIT network further down.
# The is_SPT / is_LSA / ffn_time flags are all 0 here (presumably switching those options off).
model = DifViT(img_size=9, patch_size=3, num_classes=1, dim=96,
                mlp_dim_ratio=2, depth=12, heads=12, dim_head=8, channels=3,
                stochastic_depth=0, is_SPT=0, is_LSA=0, ffn_time=0)

02-05 01:11:04: vit.py[121]: Use time embedding before FFN? 0


In [4]:
# Wrap the DifViT instance in the Gaussian diffusion module.
# NOTE(review): channels=1 here, while the DifViT above was built with channels=3 and the
# dummy batch in the next cell has 3 channels — confirm which value GaussianDiffusion expects.
diffusion_model = GaussianDiffusion(
        model,
        image_size=9,
        channels=1,
        timesteps=1000,   # number of steps
        loss_type='l1'    # L1 or L2
    )

In [5]:
import torch
dummy = torch.randn(2, 3, 9,9) # 4-D random batch, presumably (batch, channels, height, width); there is no frames axis
temp = diffusion_model(dummy)
temp.shape

torch.Size([2, 3, 9, 9])

In [6]:
import json
import numpy as np


# Load the dataset JSON into a dict.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
with open("/home/yetian/glaucoma_progression/uwhvf/alldata.json") as fin:
    dat = json.load(fin)

# Summary counts stored at the top level of the file.
print(f"Total of {dat['pts']} patients, {dat['eyes']} eyes, and {dat['hvfs']} HVFs")
# Expected output: Total of 3871 patients, 7428 eyes, and 28943 HVFs

Total of 3871 patients, 7428 eyes, and 28943 HVFs


In [7]:
datalist = [] # model inputs; re-initialised and filled in the dataset-building cell below
labellist = [] # matching targets; re-initialised and filled in the same cell

In [8]:
# Keys of dat['data'], in the order the JSON file provides them.
key_list = list(dat['data'].keys())

In [9]:
def hundred_to_zero(temp2):
    """Replace every entry equal to 100 with 0, in place, and return the array.

    Parameters
    ----------
    temp2 : numpy.ndarray
        Array to clean. It is modified in place; any number of dimensions
        (including empty arrays) is accepted.

    Returns
    -------
    numpy.ndarray
        The same array object that was passed in.
    """
    # presumably 100 is a sentinel for locations without a measurement — TODO confirm
    # against the dataset documentation.
    # A boolean mask replaces all matching cells in one vectorised step instead of
    # a Python loop over every (row, column) pair.
    temp2[temp2 == 100] = 0
    return temp2

def duplicate(temp):
    """Return a new array holding three copies of `temp` stacked along a new leading axis."""
    # Add the leading axis first, then repeat along it three times.
    return np.repeat(temp[np.newaxis, :], 3, axis=0)

In [10]:
import statistics

# Age difference between the 1st and 5th recorded test (indices 0 and 4) for every eye
# that has more than four tests. All left eyes are collected first, then all right eyes.
gap_list = []
for side in ('L', 'R'):
    for key in key_list:
        eye_records = dat['data'][key]
        if side not in eye_records.keys():
            continue
        visits = eye_records[side]
        if len(visits) > 4:
            age_diff = visits[4]['age'] - visits[0]['age']
            gap_list.append(age_diff)

print('total number of available eyes', len(gap_list))
print('average',sum(gap_list) / len(gap_list))
print('median',statistics.median(gap_list))

# The threshold is inclusive (gap >= 5) even though the printed label says ">5".
counter = sum(1 for gap in gap_list if gap >= 5)
print('number of >5 yr pairs', counter)
print('ratio of >5 yr pairs', counter/len(gap_list))
    

total number of available eyes 2065
average 5.411749887718299
median 4.6570841889117105
number of >5 yr pairs 914
ratio of >5 yr pairs 0.44261501210653753


Example: ages at which the VF tests were taken for one subject:

In [11]:
# Ages recorded for the first three left-eye tests of subject '2630'.
for visit_idx in range(3):
    print(dat['data']['2630']['L'][visit_idx]['age'])

45.22108145106092
52.982888432580424
53.94113620807666


In [12]:
def _prepare_field(td):
    """Convert one raw 'td' grid into a model-ready array with a leading axis.

    Steps, in order:
      1. pad one constant (zero) row at the bottom of the grid,
      2. replace entries equal to 100 with 0 via `hundred_to_zero`,
      3. prepend a length-1 axis.

    The original cell comments describe the padded grid as 9 x 9, so the raw
    grid is presumably 8 x 9 — TODO confirm.
    """
    field = np.pad(np.array(td), pad_width=((0, 1), (0, 0)), mode='constant')
    field = hundred_to_zero(field)
    return field[np.newaxis, :]


# Build (input, target) pairs: the input is an eye's first test (index 0) and the
# target is its third test (index 2). All left eyes are added first, then all right
# eyes, so the resulting order matches the previous two-loop version.
datalist = []
labellist = []
for side in ('L', 'R'):
    for key in key_list:
        eye_records = dat['data'][key]
        if side not in eye_records.keys():
            continue
        visits = eye_records[side]
        if len(visits) > 2:  # index 2 is read below, so at least 3 tests are required
            datalist.append(_prepare_field(visits[0]['td']))
            labellist.append(_prepare_field(visits[2]['td']))

Make sure the inputs and labels have the same length:

In [13]:
# Inputs and targets are appended in pairs, so the two lists must have equal length;
# enforce it instead of only comparing the printed numbers by eye.
assert len(datalist) == len(labellist), "datalist and labellist differ in length"
print(len(datalist)) # each entry is one padded field with a leading axis added via np.newaxis
print(len(labellist))

4452
4452


In [14]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader 
#from torchvision import datasets, models, transforms
from PIL import Image

# defining the Dataset class
class VFDataset(Dataset):
    """Index-aligned pairs of input fields and target fields.

    Args:
        label: indexable collection of targets; its length defines the dataset size.
        img: indexable collection of inputs, aligned with `label` by position.
        transform: optional callable applied to each input when it is fetched.
            ``None`` (the default) returns the stored input unchanged.
    """

    def __init__(self, label, img, transform=None):
        self.label = label
        self.img = img
        self.transform = transform

    def __len__(self):
        """Number of samples, taken from the label collection."""
        return len(self.label)

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair stored at `index`."""
        image = self.img[index]
        label = self.label[index]

        # `transform` used to be stored but silently ignored; apply it to the input
        # when one is supplied. With transform=None nothing changes.
        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [15]:
VF_datasets = {
    'train': VFDataset(img=datalist, label=labellist, transform=None),
}

# 80/20 train/validation split derived from the dataset size. The previous hard-coded
# sizes [3561, 891] are exactly what this yields for 4452 samples, but random_split
# raises as soon as the number of samples changes and the sizes no longer sum up.
full_set = VF_datasets['train']
n_train = len(full_set) * 4 // 5
n_val = len(full_set) - n_train

# A seeded generator makes the split reproducible across kernel restarts.
train_set, val_set = torch.utils.data.random_split(
    full_set,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(0),
)

dataloaders = {
    'train': torch.utils.data.DataLoader(train_set, batch_size=5, shuffle=True),
    'validation': torch.utils.data.DataLoader(val_set, batch_size=5, shuffle=False),
}


In [16]:
import torch
import torch.nn as nn

class GENVIT(nn.Module):
    """DifViT backbone wrapped in a GaussianDiffusion module.

    `forward` hands its input to the diffusion module and returns the result unchanged.
    """

    def __init__(self):
        super().__init__()

        # Backbone settings: single-channel 9x9 input, patch_size=1, LSA and the
        # pre-FFN time embedding switched on, SPT off.
        backbone_kwargs = dict(
            img_size=9, patch_size=1, num_classes=1, dim=192,
            mlp_dim_ratio=2, depth=12, heads=12, dim_head=16, channels=1,
            stochastic_depth=0, is_SPT=0, is_LSA=1, ffn_time=1,
        )
        self.model = DifViT(**backbone_kwargs)

        self.diffusion_model = GaussianDiffusion(
            self.model,
            image_size=9,
            channels=1,
            timesteps=1000,   # number of steps
            loss_type='l1',   # L1 or L2
        )

    def forward(self, x):
        return self.diffusion_model(x)

In [17]:
# Rebind `model` (until now the small DifViT smoke-test instance) to the network
# that is trained below, and print its module tree.
model = GENVIT() 

print(model)

02-05 01:11:08: vit.py[121]: Use time embedding before FFN? 1


GENVIT(
  (model): ViT(
    (time_mlp): Sequential(
      (0): SinusoidalPosEmb()
      (1): Linear(in_features=192, out_features=328, bias=True)
      (2): GELU(approximate=none)
      (3): Linear(in_features=328, out_features=328, bias=True)
    )
    (to_patch_embedding): Sequential(
      (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=1, p2=1)
      (1): Linear(in_features=1, out_features=192, bias=True)
    )
    (recon_head): Sequential(
      (0): Linear(in_features=192, out_features=1, bias=True)
      (1): Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=9, w=9, p1=1, p2=1)
    )
    (dropout): Dropout(p=0.0, inplace=False)
    (transformer): Transformer(
      (layers): ModuleList(
        (0): ModuleList(
          (0): PreNorm(
            (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
            (fn): Attention(
              (attend): Softmax(dim=-1)
              (to_qkv): Linear(in_features=192, out_features=576, bias=False)
           

In [18]:
# Shape smoke test on a random 4-D batch with a single channel.
# The name `video` looks like a leftover: the tensor has no time/frame axis.
video = torch.randn(5, 1, 9,9)
pred = model(video)
pred.shape

torch.Size([5, 1, 9, 9])

In [19]:

import torch.optim as optim
from tqdm import tqdm

# Mean absolute error serves both as the training criterion and as the reported metric.
# (The earlier `criterion = nn.MSELoss()` assignment was overwritten immediately and
# never used, so it has been dropped.)
MAE = nn.L1Loss()
criterion = MAE
optimizer = optim.Adam(model.parameters(), lr=1e-5)

In [20]:
import time
import copy

def train_model(model, criterion, optimizer, num_epochs=3):
    """Train `model` and restore the weights of the best validation epoch.

    Every epoch runs a 'train' pass followed by a 'validation' pass over the
    module-level `dataloaders` dict (keys 'train' and 'validation'). The
    per-sample average of `criterion` is tracked for both passes.

    Args:
        model: network to optimise; called as ``model(inputs.float())``.
        criterion: loss callable invoked as ``criterion(outputs, labels.float())``;
            must return a scalar tensor.
        optimizer: optimizer bound to the parameters of `model`.
        num_epochs: number of train/validation rounds.

    Returns:
        A 6-tuple ``(model, train_mse, val_mse, best_mse, save_labels, save_preds)``:
        the model with the best weights loaded, the per-epoch training losses,
        the per-epoch validation losses, the lowest validation loss, and the
        validation targets / predictions (lists of numpy arrays) from that
        best epoch. Despite the ``*_mse`` names, the values are whatever
        `criterion` computes.

        If no validation epoch ever improves (for example ``num_epochs=0`` or a
        NaN loss) the initial weights, ``float('inf')`` and two empty lists are
        returned instead of raising ``UnboundLocalError``.
    """
    since = time.time()

    best_mse = float('inf')
    # Fallbacks so the final load/return is valid even when nothing improves.
    best_model_wts = copy.deepcopy(model.state_dict())
    save_labels = []
    save_preds = []

    train_mse = []
    val_mse = []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'validation']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            # Predictions / targets of the current phase only.
            train_preds = []
            train_trues = []

            running_loss = 0.0

            for inputs, labels in tqdm(dataloaders[phase]):
                # Only track gradients while training; validation does not need
                # an autograd graph.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs.float())
                    loss = criterion(outputs, labels.float())

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                train_preds.extend(outputs.detach().cpu().numpy())
                train_trues.extend(labels.detach().cpu().numpy())

                # The loss is a batch mean; weight it by batch size so the epoch
                # value is a true per-sample average.
                running_loss += loss.item() * inputs.size(0)

            # Use the loader's own dataset instead of the train_set / val_set globals.
            epoch_loss = running_loss / len(dataloaders[phase].dataset)

            print('{} loss: {:.4f}'.format(phase, epoch_loss))

            if is_train:
                train_mse.append(epoch_loss)
            else:
                val_mse.append(epoch_loss)

                if epoch_loss < best_mse:
                    save_labels = train_trues
                    save_preds = train_preds

                    best_mse = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                    print('best_mae:', best_mse)
                    print()
                    print('A new best model saved at epoch {}!'.format(epoch + 1))

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # '{:4f}' (width 4) was a typo for '{:.4f}', the precision used for the per-epoch losses.
    print('Best val mae: {:.4f}'.format(best_mse))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, train_mse, val_mse, best_mse, save_labels, save_preds

In [None]:
# NOTE: train_model returns a 6-tuple (model, train losses, validation losses, best
# validation loss, labels, predictions), so `model_trained` holds that tuple rather
# than a model. `model` itself is updated in place by the call.
model_trained = train_model(model, criterion, optimizer, num_epochs=10)

In [None]:
# Run a single stored input through the network and print it next to the prediction.
example1 = datalist[1]
example1 = example1[np.newaxis, :]   # add the batch axis
example1 = torch.from_numpy(example1)

# Inference only: skip autograd bookkeeping so no graph is built or kept alive.
with torch.no_grad():
    pred1 = model(example1.float())

print(example1)
print(pred1)
# NOTE(review): this compares the prediction with the *input* field; training optimises
# against labellist, so labellist[1] may be the intended reference here — confirm.
print('MAE', MAE(pred1, example1))

tensor([[[[  0.0000,   0.0000,   0.0000, -10.4800,  -6.9000,  -7.6700, -17.3300,
             0.0000,   0.0000],
          [  0.0000,   0.0000,  -8.0200,  -5.9500,  -4.4200,  -8.0500,  -6.6500,
            -3.8300,   0.0000],
          [  0.0000, -11.6000,  -8.9100,  -5.5200,  -9.2500,  -6.5600,  -5.4000,
            -2.7900,  -2.5200],
          [ -8.4000,  -4.9200,  -7.0200,  -8.8300,  -7.6000,  -6.2800,  -4.8000,
            24.0000,  -5.8100],
          [ -3.8000,  -5.7400,  -3.8800,  -4.1100,  -7.8800,  -5.4300,  -2.4900,
             0.0000,  -3.3600],
          [  0.0000,  -8.0600,  -6.3200,  -3.9400,  -7.4100,  -2.0100,  -3.5000,
            -5.7600,  -3.2100],
          [  0.0000,   0.0000,  -4.8600,  -6.5800,  -8.1200,  -4.7300,  -2.1100,
            -2.3500,   0.0000],
          [  0.0000,   0.0000,   0.0000,  -9.5600,  -4.7800,  -2.4200,  -2.7500,
             0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
             0.