# Distillation Debugging

In [3]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from data_distillation.models.transformer.feature_extractors.triplet_cross_attention_vit import TripletCrossAttentionViT as TCAiT
from data_distillation.losses.triplet_losses.triplet_classification_loss import TripletClassificationLoss as TCLoss
from data_distillation.testing.data.test_triplets import TestTriplets
from data_distillation.data_distiller import DataDistiller

## Test and debug T-CAiT model

In [2]:
# Seed torch's global RNG (all devices) first, so the random input triplets and
# the model's initial weights created in later cells are reproducible on re-run.
SEED = 0
torch.manual_seed(SEED)

# Configuration for the full-size T-CAiT smoke test.
EMBED_DIM = 512
NUM_CLASSES = 21841  # presumably sized like the ImageNet-21k label space — confirm
NUM_EXTRACTOR_HEADS = 8
NUM_CLASSIFIER_HEADS = 8
BATCH_SIZE = 16
IMG_CHANNELS = 3
IMG_DIM = 224  # input images are IMG_DIM x IMG_DIM
USE_MINIPATCH = False

In [3]:
# Smoke-test construction of the full-size T-CAiT model from the config constants.
model = TCAiT(
    embed_dim=EMBED_DIM,
    num_classes=NUM_CLASSES,
    num_extractor_heads=NUM_EXTRACTOR_HEADS,
    num_classifier_heads=NUM_CLASSIFIER_HEADS,
    in_channels=IMG_CHANNELS,
    in_dim=IMG_DIM,
    extractor_use_minipatch=USE_MINIPATCH,
)

In [4]:
# Build one batch of uniform-random (anchor, positive, negative) images, each
# shaped (BATCH_SIZE, IMG_CHANNELS, IMG_DIM, IMG_DIM). The generator draws them
# in that order, so RNG consumption matches three separate torch.rand calls.
anchor, positive, negative = (
    torch.rand(BATCH_SIZE, IMG_CHANNELS, IMG_DIM, IMG_DIM) for _ in range(3)
)

In [5]:
# When a CUDA device exists, move the three input batches and the model onto it;
# otherwise leave everything on the CPU.
if torch.cuda.is_available():
    anchor, positive, negative = (t.cuda() for t in (anchor, positive, negative))

    model = model.cuda()

In [6]:
# Forward-pass smoke test. No backward pass follows, so run under no_grad to
# avoid building and retaining the autograd graph (and its activations).
with torch.no_grad():
    z_anchor, z_positive, z_negative, Y = model(anchor, positive, negative)

  return F.conv2d(input, weight, bias, self.stride,


In [7]:
# Show the model's text summary (per-parameter names, counts and shapes).
print(str(model))

Extractor
Name                                               | Params       | Size                
------------------------------------------------------------------------------------------
patcher.patch_conv.weight                          |       393216 | (512, 3, 16, 16)    
patcher.patch_conv.bias                            |          512 | (512,)              
anchor_cls_tokenizer.cls_tokens                    |          512 | (1, 1, 512)         
positive_cls_tokenizer.cls_tokens                  |          512 | (1, 1, 512)         
negative_cls_tokenizer.cls_tokens                  |          512 | (1, 1, 512)         
anchor_pos_encoder.pos_embedding                   |       100864 | (1, 197, 512)       
positive_pos_encoder.pos_embedding                 |       100864 | (1, 197, 512)       
negative_pos_encoder.pos_embedding                 |       100864 | (1, 197, 512)       
transformer_blocks.0.norm1.weight                  |          512 | (512,)              
transform

## Test DataDistiller object

In [4]:
# Seed torch's global RNG (all devices) before the model and datasets below are
# created, so this training run is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Reduced-size model configuration for the DataDistiller run.
EMBED_DIM = 128
NUM_CLASSES = 2
NUM_EXTRACTOR_HEADS = 2
NUM_CLASSIFIER_HEADS = 2
BATCH_SIZE = 16
NUM_TRAIN_BATCHES = 10
NUM_VALID_BATCHES = 2
IMG_CHANNELS = 3
IMG_DIM = 224
EXTRACTOR_DEPTH = 4
EXTRACTOR_MLP_RATIO = 2.0
CLASSIFIER_DEPTH = 1
CLASSIFIER_MLP_RATIO = 2.0
USE_MINIPATCH = False

# Training / checkpointing options.
NUM_EPOCHS = 5
SAVE_BEST_WEIGHTS = True
# Where best weights are written. Avoid a machine-specific absolute path: the
# default is relative to the working directory and assumes the notebook runs from
# the directory that contains the `data_distillation` package (TODO confirm).
# Set DISTILLATION_WEIGHTS_DIR to override.
SAVE_DIR = os.environ.get(
    'DISTILLATION_WEIGHTS_DIR',
    os.path.join('data_distillation', 'models', 'weights'),
)
SAVE_FILE = 'test.pt'
SAVE_FP = os.path.join(SAVE_DIR, SAVE_FILE)
DEVICE = 'cpu'

In [5]:
# Build the reduced-size T-CAiT model for the DataDistiller run (explicit depths
# and MLP ratios this time), then print its parameter summary.
model = TCAiT(
    embed_dim=EMBED_DIM,
    num_classes=NUM_CLASSES,
    num_extractor_heads=NUM_EXTRACTOR_HEADS,
    num_classifier_heads=NUM_CLASSIFIER_HEADS,
    in_channels=IMG_CHANNELS,
    in_dim=IMG_DIM,
    extractor_depth=EXTRACTOR_DEPTH,
    extractor_mlp_ratio=EXTRACTOR_MLP_RATIO,
    classifier_depth=CLASSIFIER_DEPTH,
    classifier_mlp_ratio=CLASSIFIER_MLP_RATIO,
    extractor_use_minipatch=USE_MINIPATCH,
)

print(model)

Extractor
Name                                               | Params       | Size                
------------------------------------------------------------------------------------------
patcher.patch_conv.weight                          |        98304 | (128, 3, 16, 16)    
patcher.patch_conv.bias                            |          128 | (128,)              
anchor_cls_tokenizer.cls_tokens                    |          128 | (1, 1, 128)         
positive_cls_tokenizer.cls_tokens                  |          128 | (1, 1, 128)         
negative_cls_tokenizer.cls_tokens                  |          128 | (1, 1, 128)         
anchor_pos_encoder.pos_embedding                   |        25216 | (1, 197, 128)       
positive_pos_encoder.pos_embedding                 |        25216 | (1, 197, 128)       
negative_pos_encoder.pos_embedding                 |        25216 | (1, 197, 128)       
transformer_blocks.0.norm1.weight                  |          128 | (128,)              
transform

In [6]:
# Synthetic triplet datasets (train first, then validation) sized by their batch
# counts, each wrapped in a DataLoader that regroups items into BATCH_SIZE batches.
train_dataset, valid_dataset = (
    TestTriplets(batch_size=BATCH_SIZE, num_batches=n_batches, num_channels=IMG_CHANNELS, dim=IMG_DIM)
    for n_batches in (NUM_TRAIN_BATCHES, NUM_VALID_BATCHES)
)
train_dataloader, valid_dataloader = (
    DataLoader(dataset=ds, batch_size=BATCH_SIZE)
    for ds in (train_dataset, valid_dataset)
)

In [7]:
# Sanity check: a TestTriplets dataset should hold num_batches * batch_size items,
# so the validation loader yields exactly NUM_VALID_BATCHES batches.
assert len(valid_dataset) == NUM_VALID_BATCHES * BATCH_SIZE, len(valid_dataset)
len(valid_dataset)

32

In [8]:
# Adam over all of the model's parameters, using PyTorch's default
# hyperparameters (no learning rate is set here).
optimizer = optim.Adam(params=model.parameters())

In [9]:
# Triplet classification loss, constructed with its default arguments.
loss_fn = TCLoss()

In [10]:
# Wire loaders, model, loss and optimizer into a DataDistiller that runs for
# NUM_EPOCHS epochs; with save_best_weights set it saves the best weights to SAVE_FP.
distiller = DataDistiller(
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    nepochs=NUM_EPOCHS,
    nclasses=NUM_CLASSES,
    save_best_weights=SAVE_BEST_WEIGHTS,
    save_fp=SAVE_FP,
    device=DEVICE,
)

In [11]:
# Run the full training/validation loop; DataDistiller prints per-epoch progress
# and the best-model save status itself.
# NOTE(review): in the recorded output, the reported best validation accuracy
# (0.4062) is lower than per-epoch validation accuracies shown in the same run
# (0.625 at EPOCH [1/5], 0.562 at EPOCH [4/5]). Confirm how DataDistiller selects
# and reports its "best" model (e.g. by loss rather than accuracy, or a
# different averaging than the progress-bar postfix).
distiller.run_main_loop()


---------------------------------------------------------------------------------------------
EPOCH [0/5]
---------------------------------------------------------------------------------------------


Training, Batch [9/10]: 100%|██████████| 10/10 [00:08<00:00,  1.22it/s, accuracy=0.475, loss=20.8]
Validation, Batch [1/2]: 100%|██████████| 2/2 [00:00<00:00,  2.22it/s, accuracy=0.531, loss=13.9]



---------------------------------------------------------------------------------------------
EPOCH [1/5]
---------------------------------------------------------------------------------------------


Training, Batch [9/10]: 100%|██████████| 10/10 [00:06<00:00,  1.55it/s, accuracy=0.519, loss=8.22]
Validation, Batch [1/2]: 100%|██████████| 2/2 [00:00<00:00,  2.18it/s, accuracy=0.625, loss=4.22]



---------------------------------------------------------------------------------------------
EPOCH [2/5]
---------------------------------------------------------------------------------------------


Training, Batch [9/10]: 100%|██████████| 10/10 [00:06<00:00,  1.59it/s, accuracy=0.506, loss=2.64]
Validation, Batch [1/2]: 100%|██████████| 2/2 [00:00<00:00,  2.26it/s, accuracy=0.406, loss=1.41]



---------------------------------------------------------------------------------------------
EPOCH [3/5]
---------------------------------------------------------------------------------------------


Training, Batch [9/10]: 100%|██████████| 10/10 [00:06<00:00,  1.62it/s, accuracy=0.506, loss=0.975]
Validation, Batch [1/2]: 100%|██████████| 2/2 [00:00<00:00,  2.36it/s, accuracy=0.406, loss=0.796]



---------------------------------------------------------------------------------------------
EPOCH [4/5]
---------------------------------------------------------------------------------------------


Training, Batch [9/10]: 100%|██████████| 10/10 [00:06<00:00,  1.60it/s, accuracy=0.419, loss=0.798]
Validation, Batch [1/2]: 100%|██████████| 2/2 [00:00<00:00,  2.29it/s, accuracy=0.562, loss=0.776]


BEST VALIDATION MODEL ACCURACY: 0.4062

Attempting to save best model weights...
	Save successful!



