In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
## Choose appropriate GPU
%env CUDA_VISIBLE_DEVICES=5
%env OMP_NUM_THREADS=8
%env MKL_NUM_THREADS=8

env: CUDA_VISIBLE_DEVICES=5
env: OMP_NUM_THREADS=8
env: MKL_NUM_THREADS=8


In [4]:
import torch
import torchvision
import torchvision.transforms as transforms

from data import Data2VecDataset

train_dataset = Data2VecDataset('/mnt/data/imagenet/train')
val_dataset = Data2VecDataset('/mnt/data/imagenet/val')

In [5]:
def data2vec_collate_fn(samples):
    imgs, masks = zip(*samples)
    return dict(pixel_values=torch.stack(imgs), bool_masked_pos=torch.stack(masks))

batch_size = 256
num_workers = 16

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, collate_fn=data2vec_collate_fn, num_workers=num_workers)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size, collate_fn=data2vec_collate_fn, num_workers=num_workers)

In [6]:
from vit_data2vec import ViTForData2Vec, ViTConfigForData2Vec

config = ViTConfigForData2Vec(
    # data2vec hyperparams
    n_layers_to_average=8,
    huber_loss_delta=2.0,
    
    # ViT-B hyperparams
    hidden_size=768, 
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    initializer_range=0.02,
    layer_norm_eps=1e-12,
    is_encoder_decoder=False,
    image_size=224,
    patch_size=16,
    num_channels=3,
    qkv_bias=True,
    encoder_stride=16
)

In [7]:
device = 'cuda:0'

In [8]:
model = ViTForData2Vec(config).to(device)

In [9]:
opt = torch.optim.Adam(model.student.parameters(), lr=1e-4)

In [10]:
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [6]:
import torch
torch.addcmul()

TypeError: addcmul() received an invalid combination of arguments - got (), but expected (Tensor input, Tensor tensor1, Tensor tensor2, *, Number value, Tensor out)

In [None]:
n_epochs = 5
momentum = 0.9998
update_freq = 4
log_freq = 40

for epoch in tqdm(range(n_epochs), desc='Epoch'):
    for step, batch in enumerate(tqdm(train_loader, desc='Batch', leave=False)):
        batch['pixel_values'] = batch['pixel_values'].to(device)
        batch['bool_masked_pos'] = batch['bool_masked_pos'].to(device)
        
        outputs = model(**batch)
        loss = outputs.loss        
        loss.backward()
        
        if step % update_freq == 0:
            opt.step()
            opt.zero_grad()
            
            model.update_teacher(momentum)
        
        if step % log_freq == 0:
            print(loss.item())
            
            # Could look at the distribution of output entries
            #plt.hist(outputs.prediction.detach().cpu().numpy().ravel(), alpha=0.3)
            #plt.hist(outputs.target.detach().cpu().numpy().ravel(), alpha=0.3)
            #plt.ylabel('Frequency')
            #plt.xlabel('Output value')
            #plt.show()

Epoch:   0%|          | 0/5 [00:00<?, ?it/s]

Batch:   0%|          | 0/5005 [00:00<?, ?it/s]

1.9203659296035767
0.27822282910346985
0.09442088007926941
0.0766567587852478
0.09168350696563721
0.06233667582273483
0.056215494871139526
0.18188852071762085
0.05987240746617317
0.059019941836595535
0.04463276267051697
0.0398867204785347
0.03693309798836708
0.03414800390601158
0.07123082131147385
0.04626847803592682
0.0409870445728302
0.030353739857673645
0.01685752160847187
0.031272247433662415
0.029573682695627213
0.026683809235692024
0.02493894100189209
0.03403769060969353
0.026136333122849464
0.02686288207769394
0.027769921347498894
0.02693517506122589
0.029408009722828865
0.03125990182161331
0.02900662086904049
0.024813253432512283
0.02809232473373413
0.023744763806462288
0.022244183346629143
