# Mount Google Drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Install requirements

In [2]:
!pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-1.5.10-py3-none-any.whl (527 kB)
[?25l[K     |▋                               | 10 kB 32.1 MB/s eta 0:00:01[K     |█▎                              | 20 kB 37.8 MB/s eta 0:00:01[K     |█▉                              | 30 kB 21.0 MB/s eta 0:00:01[K     |██▌                             | 40 kB 17.4 MB/s eta 0:00:01[K     |███                             | 51 kB 17.0 MB/s eta 0:00:01[K     |███▊                            | 61 kB 15.6 MB/s eta 0:00:01[K     |████▍                           | 71 kB 13.6 MB/s eta 0:00:01[K     |█████                           | 81 kB 15.1 MB/s eta 0:00:01[K     |█████▋                          | 92 kB 13.7 MB/s eta 0:00:01[K     |██████▏                         | 102 kB 14.7 MB/s eta 0:00:01[K     |██████▉                         | 112 kB 14.7 MB/s eta 0:00:01[K     |███████▌                        | 122 kB 14.7 MB/s eta 0:00:01[K     |████████                        | 1

# Import relevant libraries

In [3]:
import sys
import os
module_path = os.path.join(os.path.abspath(''), "drive/MyDrive/Master_Thesis_Profactor/zdmp/")
if module_path not in sys.path:
  sys.path.append(module_path)

import zdmp
from utils.lightning_classifier import Classifier
from utils.external_utils.vision_transformer import VitGenerator

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader

# Train ViT Model

In [4]:

def save_model(model, name):
    path = os.path.join(
        "/content/drive/MyDrive/Master_Thesis_Profactor/zdmp/pretrained_weights",
        f"{name}.pth.tar"
    )
    torch.save(model.state_dict(), path)

In [5]:
# Make the results reproducible
pl.seed_everything(42)

# Prepare dataset
ds_train = zdmp.get_train_data(vit=True)
ds_test = zdmp.get_test_data(vit=True)

# Prepare Dataloaders
dl_train = DataLoader(ds_train, batch_size=16, num_workers=2, shuffle=True)
dl_test = DataLoader(ds_test, batch_size=16, num_workers=2)

# Prepare model
model = VitGenerator('vit_small', 
                     8, 
                     'cuda', 
                     evaluate=False, 
                     random=False, 
                     verbose=True,
                     num_classes=2
                     )

# Prepare Classifier to train the model
classifier = Classifier(model.model)

# Prepare the Trainer
trainer = pl.Trainer(gpus=1, precision=16, max_epochs=3)

# Train
trainer.fit(classifier, dl_train, dl_test)

# Save the trained weights
save_model(classifier.model, "vit")

Global seed set to 42


zdmp - ViT size - True
get_data - ViT size - True
transform - ViT size - True
[INFO] Initializing vit_small with patch size of 8
[INFO] Loading weights


Downloading: "https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dino_deitsmall8_300ep_pretrain.pth


  0%|          | 0.00/82.7M [00:00<?, ?B/s]

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type              | Params
----------------------------------------------------
0 | model         | VisionTransformer | 21.7 M
1 | train_metrics | MetricCollection  | 0     
2 | valid_metrics | MetricCollection  | 0     
----------------------------------------------------
21.7 M    Trainable params
0         Non-trainable params
21.7 M    Total params
43.342    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

# Tensorboard
For both train and validation:
* Accuracy
* Precission
* Recall
* Loss

In [6]:
!tensorboard dev upload --logdir '/content/lightning_logs'


***** TensorBoard Uploader *****

This will upload your TensorBoard logs to https://tensorboard.dev/ from
the following directory:

/content/lightning_logs

This TensorBoard will be visible to everyone. Do not upload sensitive
data.

Your use of this service is subject to Google's Terms of Service
<https://policies.google.com/terms> and Privacy Policy
<https://policies.google.com/privacy>, and TensorBoard.dev's Terms of Service
<https://tensorboard.dev/policy/terms/>.

This notice will not be shown again while you are logged into the uploader.
To log out, run `tensorboard dev auth revoke`.

Continue? (yes/NO) y

Please visit this URL to authorize this application: https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=373649185512-8v619h5kft38l4456nm2dj4ubeqsrvh6.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email&state=jozjFMttvGgDeGOkA0TYOxcp1pZwhY&prompt=consent&access_type

# Test trained model

In [7]:
wrong = 0
for i, (img, lbl) in enumerate(ds_test):
    out = int(model(img.unsqueeze(0)).softmax(1).argmax())
    if out != lbl:
        wrong+=1
    print(f"{i}\tPred: {out}\tTrue:{lbl}\t{'' if out==lbl else '*'*5}")
print(f"Wrong: {wrong}/{len(ds_test)}")

0	Pred: 1	True:1	
1	Pred: 1	True:1	
2	Pred: 1	True:1	
3	Pred: 1	True:1	
4	Pred: 0	True:0	
5	Pred: 0	True:0	
6	Pred: 0	True:0	
7	Pred: 0	True:0	
8	Pred: 1	True:1	
9	Pred: 0	True:0	
10	Pred: 0	True:0	
11	Pred: 1	True:1	
12	Pred: 0	True:0	
13	Pred: 0	True:0	
14	Pred: 1	True:1	
15	Pred: 1	True:1	
16	Pred: 0	True:0	
17	Pred: 0	True:0	
18	Pred: 1	True:1	
19	Pred: 1	True:1	
20	Pred: 1	True:1	
21	Pred: 0	True:0	
22	Pred: 0	True:0	
23	Pred: 1	True:1	
24	Pred: 0	True:0	
25	Pred: 1	True:1	
26	Pred: 0	True:0	
27	Pred: 0	True:0	
28	Pred: 0	True:0	
29	Pred: 0	True:0	
30	Pred: 1	True:1	
31	Pred: 0	True:0	
32	Pred: 1	True:1	
33	Pred: 0	True:0	
34	Pred: 0	True:0	
35	Pred: 0	True:0	
36	Pred: 0	True:0	
37	Pred: 0	True:0	
38	Pred: 0	True:0	
39	Pred: 0	True:0	
40	Pred: 1	True:1	
41	Pred: 1	True:1	
42	Pred: 1	True:1	
43	Pred: 0	True:0	
44	Pred: 0	True:0	
45	Pred: 1	True:1	
46	Pred: 0	True:0	
47	Pred: 1	True:1	
48	Pred: 1	True:1	
49	Pred: 0	True:0	
50	Pred: 0	True:0	
51	Pred: 0	True:0	
52	Pred: 0	True:0	
53	

# Visualize Attention