### Load required packages

In [1]:
import time
import torch
from torch import nn
from ViT import ViT
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import numpy as np
import wandb 

  from .autonotebook import tqdm as notebook_tqdm


### Read in training data

In [2]:
# create an example of training data x and label y; 
# practical case should have more tarining data and split train/test sets
x = np.random.rand(1,224,224,1) 
y = np.random.rand(1,1)

#transfer to torch.tensor
x = torch.from_numpy(x.reshape(1,1,224,224)).float()
y = torch.from_numpy(y.reshape(-1,1)).float()

In [3]:
class Mydata(Dataset):
    
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets
        
    def __getitem__(self, index):
        image = self.data[index]
        label = self.targets[index]
        return image, label
    
    def __len__(self):
        return len(self.data)

dataset = Mydata(x,y)
train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True) 
val_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)  
##########################################
### below is for train/test data split ###
##########################################

#train_size = int(0.7*len(dataset))
#test_size = len(dataset) - train_size
#train_set, val_set = random_split(dataset, [train_size, test_size]) 

#train_dataloader = DataLoader(train_set, batch_size=32, shuffle=True) 
#val_dataloader = DataLoader(val_set, batch_size=32, shuffle=True)     

#train_features, train_labels = next(iter(train_dataloader))
#print(f"Feature batch shape: {train_features.size()}")
#print(f"Labels batch shape: {train_labels.size()}")


### Train Vision Transformers (here we use ViT as an example, train SwinT can follow similar way)
> Step1: load ViT model

In [4]:
# train with GPU otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#load model (play with model architectures such as depth, patch size, number of heads...)
model = ViT(
            image_size = 224,
            patch_size = 8,
            num_classes = 1,
            dim = 32, 
            depth = 6,
            heads = 3,
            mlp_dim = 128, 
            dropout = 0.0,
            emb_dropout = 0.0
            )

model.to(device)

ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=8, p2=8)
    (1): Linear(in_features=64, out_features=32, bias=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0): ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (fn): Attention(
            (attend): Softmax(dim=-1)
            (to_qkv): Linear(in_features=32, out_features=576, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=192, out_features=32, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (fn): FeedForward(
            (net): Sequential(
              (0): Linear(in_features=32, out_features=128, bias=True)
              (1): ReLU()
              (2): Dropout(

> step2: define training parameters and train ViT

In [5]:
#define loss
loss_fn = nn.MSELoss()
loss_fn = loss_fn.cuda()

#define optimizer
lr = 1e-6
optimizer = torch.optim.Adam(model.parameters(),lr=lr, eps=1e-8, weight_decay=1e-4)

#traiing parameters
total_train_step = 0
total_val_step = 0
epoch = 300

#train
start_time = time.time()
for i in range(epoch):
    print('----------This is the {} epoch training---------'.format(i+1))
    train_loss = 0
    
    #below is for training
    model.train()
    for imgs, labels in train_dataloader: #for 1 epoch, 66 steps are needed (train data/ batch size)    
        imgs, labels = imgs.to(device), labels.to(device)            
        outputs = model(imgs)
        loss = loss_fn(outputs, labels)
        #optimizer
        optimizer.zero_grad()
        loss.backward()
        optimizer.step() #update weights        
        total_train_step = total_train_step + 1 #step means iteration, 1 iter means train the number of (batch size) samples        
        train_loss += loss.item()
    print('total train loss:{}'.format(train_loss))
    print('time elapsed: {:.4f}s'.format(time.time()-start_time))
    
    #below is for evaluation
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for imgs, labels in val_dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = loss_fn(outputs, labels)
            total_val_loss = total_val_loss + loss.item()             
            total_val_step += 1 
      
    print('total val loss:{}'.format(total_val_loss))

----------This is the 1 epoch training---------
total train loss:0.42194271087646484
time elapsed: 1.5211s
total val loss:0.4207490384578705
----------This is the 2 epoch training---------
total train loss:0.4207490384578705
time elapsed: 1.5531s
total val loss:0.4195575416088104
----------This is the 3 epoch training---------
total train loss:0.4195575416088104
time elapsed: 1.5771s
total val loss:0.41836708784103394
----------This is the 4 epoch training---------
total train loss:0.41836708784103394
time elapsed: 1.6081s
total val loss:0.41717857122421265
----------This is the 5 epoch training---------
total train loss:0.41717857122421265
time elapsed: 1.6381s
total val loss:0.4159916043281555
----------This is the 6 epoch training---------
total train loss:0.4159916043281555
time elapsed: 1.6720s
total val loss:0.4148068428039551
----------This is the 7 epoch training---------
total train loss:0.4148068428039551
time elapsed: 1.6990s
total val loss:0.41362348198890686
----------This

> step3: predict the testing data after the model is fully trained (play with different parameter settings and fine tuning)

In [6]:
# use model.predict() to predict properties of interest of testing data
example_test = np.random.rand(1,1,224,224)
model(torch.from_numpy(example_test).float().to(device))

tensor([[0.2341]], device='cuda:0', grad_fn=<AddmmBackward0>)