# Distillation Debugging

In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from data_distillation.models.transformer.feature_extractors.triplet_cross_attention_vit import TripletCrossAttentionViT as TCAiT
from data_distillation.models.transformer.feature_extractors.pyramid.pyra_tcait import PyraTCAiT

from data_distillation.losses.triplet_losses.triplet_classification_loss import TripletClassificationLoss as TCLoss
from data_distillation.losses.triplet_losses.triplet_loss import TripletLoss

from data_distillation.testing.data.test_triplets import TestTriplets
from data_distillation.data_distiller import DataDistiller

## Test and debug TCAiT model

In [2]:
# Geometry of the dummy image batches fed to the model below.
BATCH_SIZE = 16
IMG_CHANNELS = 3
IMG_DIM = 224

# Keyword arguments forwarded to the TCAiT constructor.
EMBED_DIM = 512
NUM_CLASSES = 21841
NUM_EXTRACTOR_HEADS = 8
NUM_CLASSIFIER_HEADS = 8
USE_MINIPATCH = False

In [3]:
# test initialization

# Instantiate T-CAiT from the constants configured in the previous cell.
model = TCAiT(
    embed_dim=EMBED_DIM,
    num_classes=NUM_CLASSES,
    num_extractor_heads=NUM_EXTRACTOR_HEADS,
    num_classifier_heads=NUM_CLASSIFIER_HEADS,
    in_channels=IMG_CHANNELS,
    in_dim=IMG_DIM,
    extractor_use_minipatch=USE_MINIPATCH,
)

In [4]:
# test forward function

# Build a random (anchor, positive, negative) image triplet for the forward test.
# A dedicated, seeded generator makes the inputs reproducible across kernel
# restarts without altering the global RNG state used elsewhere in the notebook.
_triplet_rng = torch.Generator().manual_seed(0)
_triplet_shape = (BATCH_SIZE, IMG_CHANNELS, IMG_DIM, IMG_DIM)

anchor = torch.rand(_triplet_shape, generator=_triplet_rng)
positive = torch.rand(_triplet_shape, generator=_triplet_rng)
negative = torch.rand(_triplet_shape, generator=_triplet_rng)

In [5]:
if torch.cuda.is_available():
    # Move the three inputs and the model onto the GPU together so they share a device.
    anchor, positive, negative = (t.cuda() for t in (anchor, positive, negative))
    model = model.cuda()

In [6]:
# Smoke-test the forward pass: the model returns three values unpacked as
# embeddings (z_*) plus Y. No backward pass follows in this section, so run
# under no_grad to avoid building and retaining an autograd graph.
with torch.no_grad():
    z_anchor, z_positive, z_negative, Y = model(anchor, positive, negative)

In [7]:
# test string function
# print() renders the model through str(model); the recorded output below is a
# per-parameter table (name | parameter count | size).

print(model)

Extractor
Name                                                                   | Params       | Size                
--------------------------------------------------------------------------------------------------------------
patcher.patch_conv.weight                                              |       393216 | (512, 3, 16, 16)    
patcher.patch_conv.bias                                                |          512 | (512,)              
anchor_cls_tokenizer.cls_tokens                                        |          512 | (1, 1, 512)         
positive_cls_tokenizer.cls_tokens                                      |          512 | (1, 1, 512)         
negative_cls_tokenizer.cls_tokens                                      |          512 | (1, 1, 512)         
anchor_pos_encoder.pos_embedding                                       |       100864 | (1, 197, 512)       
positive_pos_encoder.pos_embedding                                     |       100864 | (1, 197, 512)       
negativ

## Test and debug PyraTCAiT model

In [2]:
# Keyword arguments forwarded to the PyraTCAiT constructor.
# The list-valued settings each have four entries — presumably one per pyramid
# stage (TODO confirm against PyraTCAiT).
DEPTHS = [3, 3, 6, 3]
EMBED_DIMS = [64, 128, 320, 512]
HEAD_COUNTS = [1, 2, 5, 8]
MLP_RATIOS = [8, 8, 4, 4]
SR_RATIOS = [8, 4, 2, 1]

# Classifier options; note NUM_CLASSES shadows the value set in the TCAiT section.
ADD_CLASSIFIER = True
NUM_CLASSES = 2

In [3]:
# test initialization

# Instantiate PyraT-CAiT from the constants configured in the previous cell.
model = PyraTCAiT(
    embed_dims=EMBED_DIMS,
    head_counts=HEAD_COUNTS,
    mlp_ratios=MLP_RATIOS,
    sr_ratios=SR_RATIOS,
    depths=DEPTHS,
    add_classifier=ADD_CLASSIFIER,
    num_classes=NUM_CLASSES,
)

In [4]:
# Geometry of the dummy image batches: batch size, channel count, square image side.
BATCH_SIZE, IMG_CHANNELS, IMG_DIM = 16, 3, 224

In [5]:
# test forward function

# Build a random (anchor, positive, negative) image triplet for the forward test.
# A dedicated, seeded generator makes the inputs reproducible across kernel
# restarts without altering the global RNG state used elsewhere in the notebook.
_triplet_rng = torch.Generator().manual_seed(0)
_triplet_shape = (BATCH_SIZE, IMG_CHANNELS, IMG_DIM, IMG_DIM)

anchor = torch.rand(_triplet_shape, generator=_triplet_rng)
positive = torch.rand(_triplet_shape, generator=_triplet_rng)
negative = torch.rand(_triplet_shape, generator=_triplet_rng)

In [6]:
if torch.cuda.is_available():
    # Move the three inputs and the model onto the GPU together so they share a device.
    anchor, positive, negative = (t.cuda() for t in (anchor, positive, negative))
    model = model.cuda()

In [7]:
# Smoke-test the forward pass: the model returns three values unpacked as
# embeddings (z_*) plus Y. No backward pass follows in this section, so run
# under no_grad to avoid building and retaining an autograd graph.
with torch.no_grad():
    z_anchor, z_positive, z_negative, Y = model(anchor, positive, negative)

In [8]:
# test string function
# print() renders the model through str(model); the recorded output below is a
# per-stage, per-parameter table (name | parameter count | size).

print(model)

Stage 0
Name                                                                   | Params       | Size                
--------------------------------------------------------------------------------------------------------------
patcher.patch_conv.weight                                              |         3072 | (64, 3, 4, 4)       
patcher.patch_conv.bias                                                |           64 | (64,)               
patcher.norm.weight                                                    |           64 | (64,)               
patcher.norm.bias                                                      |           64 | (64,)               
anchor_pos_encoder.pos_embedding                                       |       200704 | (1, 3136, 64)       
positive_pos_encoder.pos_embedding                                     |       200704 | (1, 3136, 64)       
negative_pos_encoder.pos_embedding                                     |       200704 | (1, 3136, 64)       
transform

## Test DataDistiller object

In [9]:
# --- model / data settings (forwarded to TCAiT and TestTriplets below) ---
EMBED_DIM = 128
NUM_CLASSES = 2
NUM_EXTRACTOR_HEADS = 2
NUM_CLASSIFIER_HEADS = 2
BATCH_SIZE = 16
NUM_TRAIN_BATCHES = 4
NUM_VALID_BATCHES = 1
IMG_CHANNELS = 3
IMG_DIM = 224
EXTRACTOR_DEPTH = 4
EXTRACTOR_MLP_RATIO = 2.0
CLASSIFIER_DEPTH = 1
CLASSIFIER_MLP_RATIO = 2.0
USE_MINIPATCH = False

# --- training / persistence settings (forwarded to DataDistiller below) ---
NUM_EPOCHS = 5
SAVE_BEST_WEIGHTS = True
# The weights directory can be overridden with the DATA_DISTILLATION_WEIGHTS_DIR
# environment variable so the notebook is not tied to one machine; when unset it
# falls back to the original author's path, keeping existing behavior.
SAVE_DIR = os.environ.get(
    'DATA_DISTILLATION_WEIGHTS_DIR',
    '/Users/charlieclark/Documents/GATech/OMSCS/CichlidBowerTracking/cichlid_bower_tracking/data_distillation/models/weights',
)
SAVE_FILE = 'test.pt'
# os.path.join handles separators (and a trailing slash in SAVE_DIR) portably.
SAVE_FP = os.path.join(SAVE_DIR, SAVE_FILE)
DEVICE = 'cpu'
GPU_ID = -1

In [10]:
# setup simple T-CAiT model
# Small configuration (values come from the constants cell above) so the
# DataDistiller test runs quickly.
model = TCAiT(
    embed_dim=EMBED_DIM,
    num_classes=NUM_CLASSES,
    num_extractor_heads=NUM_EXTRACTOR_HEADS,
    num_classifier_heads=NUM_CLASSIFIER_HEADS,
    in_channels=IMG_CHANNELS,
    in_dim=IMG_DIM,
    extractor_depth=EXTRACTOR_DEPTH,
    extractor_mlp_ratio=EXTRACTOR_MLP_RATIO,
    classifier_depth=CLASSIFIER_DEPTH,
    classifier_mlp_ratio=CLASSIFIER_MLP_RATIO,
    extractor_use_minipatch=USE_MINIPATCH,
)

print(model)

Extractor
Name                                                                   | Params       | Size                
--------------------------------------------------------------------------------------------------------------
patcher.patch_conv.weight                                              |        98304 | (128, 3, 16, 16)    
patcher.patch_conv.bias                                                |          128 | (128,)              
anchor_cls_tokenizer.cls_tokens                                        |          128 | (1, 1, 128)         
positive_cls_tokenizer.cls_tokens                                      |          128 | (1, 1, 128)         
negative_cls_tokenizer.cls_tokens                                      |          128 | (1, 1, 128)         
anchor_pos_encoder.pos_embedding                                       |        25216 | (1, 197, 128)       
positive_pos_encoder.pos_embedding                                     |        25216 | (1, 197, 128)       
negativ

In [11]:
# setup datasets and dataloaders

# Training split: a TestTriplets dataset wrapped in a DataLoader.
train_dataset = TestTriplets(
    batch_size=BATCH_SIZE,
    num_batches=NUM_TRAIN_BATCHES,
    num_channels=IMG_CHANNELS,
    dim=IMG_DIM,
)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE)

# Validation split: same construction, sized by NUM_VALID_BATCHES.
valid_dataset = TestTriplets(
    batch_size=BATCH_SIZE,
    num_batches=NUM_VALID_BATCHES,
    num_channels=IMG_CHANNELS,
    dim=IMG_DIM,
)
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE)

In [12]:
# Sanity-check the size of the validation dataset.
# NOTE(review): the recorded output (300) does not equal
# BATCH_SIZE * NUM_VALID_BATCHES (16); either the output is stale or TestTriplets
# sizes itself differently — confirm.
len(valid_dataset)

300

In [13]:
# setup optimizer
# Adam over every model parameter; no hyper-parameters are passed, so the
# optimizer runs with the library defaults.
optimizer = optim.Adam(model.parameters())

In [14]:
# setup loss function
# TCLoss is this notebook's import alias for TripletClassificationLoss.
loss_fn = TCLoss()

In [15]:
# set up datadistiller
# NOTE(review): the recorded output of this cell is
# "TypeError: __init__() got an unexpected keyword argument 'save_best_weights'",
# so these keyword arguments appear to be out of sync with the installed
# DataDistiller.__init__ — confirm the current signature before relying on this cell.
distiller = DataDistiller(
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    nepochs=NUM_EPOCHS,
    nclasses=NUM_CLASSES,
    save_best_weights=SAVE_BEST_WEIGHTS,
    save_fp=SAVE_FP,
    device=DEVICE,
    gpu_id=GPU_ID,
)

TypeError: __init__() got an unexpected keyword argument 'save_best_weights'

In [None]:
# perform training/validation
# The recorded output below shows one training and one validation progress bar
# per epoch (accuracy and loss), then the best validation accuracy and a
# weight-save message.
distiller.run_main_loop()


---------------------------------------------------------------------------------------------
EPOCH [0/5]
---------------------------------------------------------------------------------------------


Training, Batch [18/19]: 100%|██████████| 19/19 [00:11<00:00,  1.66it/s, accuracy=0.518, loss=12.5]
Validation, Batch [18/19]: 100%|██████████| 19/19 [00:05<00:00,  3.47it/s, accuracy=0.548, loss=5.72]



---------------------------------------------------------------------------------------------
EPOCH [1/5]
---------------------------------------------------------------------------------------------


Training, Batch [18/19]: 100%|██████████| 19/19 [00:09<00:00,  1.98it/s, accuracy=0.473, loss=3.76]
Validation, Batch [18/19]: 100%|██████████| 19/19 [00:05<00:00,  3.36it/s, accuracy=0.51, loss=1.18]



---------------------------------------------------------------------------------------------
EPOCH [2/5]
---------------------------------------------------------------------------------------------


Training, Batch [18/19]: 100%|██████████| 19/19 [00:10<00:00,  1.80it/s, accuracy=0.496, loss=0.906]
Validation, Batch [18/19]: 100%|██████████| 19/19 [00:05<00:00,  3.42it/s, accuracy=0.493, loss=0.852]



---------------------------------------------------------------------------------------------
EPOCH [3/5]
---------------------------------------------------------------------------------------------


Training, Batch [18/19]: 100%|██████████| 19/19 [00:10<00:00,  1.78it/s, accuracy=0.453, loss=0.896]
Validation, Batch [18/19]: 100%|██████████| 19/19 [00:05<00:00,  3.42it/s, accuracy=0.5, loss=1.05]



---------------------------------------------------------------------------------------------
EPOCH [4/5]
---------------------------------------------------------------------------------------------


Training, Batch [18/19]: 100%|██████████| 19/19 [00:11<00:00,  1.62it/s, accuracy=0.43, loss=0.955] 
Validation, Batch [18/19]: 100%|██████████| 19/19 [00:05<00:00,  3.55it/s, accuracy=0.492, loss=0.816]



BEST VALIDATION MODEL ACCURACY: 0.4923

Attempting to save best model weights...
	Save successful!


In [None]:
# --- PyraTCAiT settings (forwarded to the constructor below) ---
# The list-valued settings each have four entries — presumably one per pyramid
# stage (TODO confirm against PyraTCAiT).
EMBED_DIMS = [12, 24, 48, 96]
HEAD_COUNTS = [1, 2, 4, 6]
MLP_RATIOS = [4, 4, 2, 2]
SR_RATIOS = [8, 4, 2, 1]
DEPTHS = [1, 2, 4, 2]
ADD_CLASSIFIER = True
NUM_CLASSES = 2
INIT_ALPHA = 0.1
INIT_BETA = 0.1
USE_IMPROVED = True

# --- training / persistence settings (forwarded to DataDistiller below) ---
NUM_EPOCHS = 2
SAVE_BEST_WEIGHTS = True
# The weights directory can be overridden with the DATA_DISTILLATION_WEIGHTS_DIR
# environment variable so the notebook is not tied to one machine; when unset it
# falls back to the original author's path, keeping existing behavior.
SAVE_DIR = os.environ.get(
    'DATA_DISTILLATION_WEIGHTS_DIR',
    '/Users/charlieclark/Documents/GATech/OMSCS/CichlidBowerTracking/cichlid_bower_tracking/data_distillation/models/weights',
)
SAVE_FILE = 'test2.pt'
# os.path.join handles separators (and a trailing slash in SAVE_DIR) portably.
SAVE_FP = os.path.join(SAVE_DIR, SAVE_FILE)
DEVICE = 'cpu'
# Integer, matching the TCAiT DataDistiller configuration earlier in this notebook
# (this cell previously used the string '-1').
# NOTE(review): confirm the type DataDistiller expects for gpu_id.
GPU_ID = -1

In [None]:
# setup simple PyraT-CAiT model
# Small configuration (values come from the constants cell above) so the
# DataDistiller test stays tractable on CPU.
model = PyraTCAiT(
    embed_dims=EMBED_DIMS,
    head_counts=HEAD_COUNTS,
    mlp_ratios=MLP_RATIOS,
    sr_ratios=SR_RATIOS,
    depths=DEPTHS,
    add_classifier=ADD_CLASSIFIER,
    num_classes=NUM_CLASSES,
    init_alpha=INIT_ALPHA,
    init_beta=INIT_BETA,
    use_improved=USE_IMPROVED,
)

print(model)

Stage 0
Name                                                                   | Params       | Size                
--------------------------------------------------------------------------------------------------------------
patcher.patch_conv.weight                                              |          576 | (12, 3, 4, 4)       
patcher.patch_conv.bias                                                |           12 | (12,)               
patcher.norm.weight                                                    |           12 | (12,)               
patcher.norm.bias                                                      |           12 | (12,)               
anchor_pos_encoder.pos_embedding                                       |        37632 | (1, 3136, 12)       
positive_pos_encoder.pos_embedding                                     |        37632 | (1, 3136, 12)       
negative_pos_encoder.pos_embedding                                     |        37632 | (1, 3136, 12)       
transform

In [None]:
# setup datasets and dataloaders
# BATCH_SIZE, NUM_*_BATCHES, IMG_CHANNELS and IMG_DIM are not redefined in this
# section; they carry over from the earlier configuration cells.

# Training split: a TestTriplets dataset wrapped in a DataLoader.
train_dataset = TestTriplets(
    batch_size=BATCH_SIZE,
    num_batches=NUM_TRAIN_BATCHES,
    num_channels=IMG_CHANNELS,
    dim=IMG_DIM,
)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE)

# Validation split: same construction, sized by NUM_VALID_BATCHES.
valid_dataset = TestTriplets(
    batch_size=BATCH_SIZE,
    num_batches=NUM_VALID_BATCHES,
    num_channels=IMG_CHANNELS,
    dim=IMG_DIM,
)
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE)

In [None]:
# Sanity-check the size of the validation dataset.
# NOTE(review): the recorded output (300) does not equal
# BATCH_SIZE * NUM_VALID_BATCHES (16); either the output is stale or TestTriplets
# sizes itself differently — confirm.
len(valid_dataset)

300

In [None]:
# setup optimizer
# Adam over every model parameter; no hyper-parameters are passed, so the
# optimizer runs with the library defaults.
optimizer = optim.Adam(model.parameters())

In [None]:
# setup loss function
# Pick the loss that matches the model head: TCLoss (TripletClassificationLoss)
# when a classifier is attached, plain TripletLoss otherwise.
if ADD_CLASSIFIER:
    loss_fn = TCLoss()
else:
    loss_fn = TripletLoss()

In [None]:
# set up datadistiller
# NOTE(review): this call passes `save_best_weights`; the equivalent call in the
# TCAiT section of this notebook recorded a TypeError for that keyword, so verify
# it against the current DataDistiller.__init__ signature.
distiller = DataDistiller(
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    nepochs=NUM_EPOCHS,
    nclasses=NUM_CLASSES,
    save_best_weights=SAVE_BEST_WEIGHTS,
    save_fp=SAVE_FP,
    device=DEVICE,
    gpu_id=GPU_ID,
)

In [None]:
# perform training/validation
# The recorded output below shows one training and one validation progress bar
# per epoch (loss only), then the best validation loss and a weight-save message.
distiller.run_main_loop()


---------------------------------------------------------------------------------------------
EPOCH [0/2]
---------------------------------------------------------------------------------------------


Training, Batch [18/19]: 100%|██████████| 19/19 [02:23<00:00,  7.55s/it, loss=0.732]
Validation, Batch [18/19]: 100%|██████████| 19/19 [00:31<00:00,  1.67s/it, loss=0.755]



---------------------------------------------------------------------------------------------
EPOCH [1/2]
---------------------------------------------------------------------------------------------


Training, Batch [18/19]: 100%|██████████| 19/19 [01:22<00:00,  4.36s/it, loss=0.723]
Validation, Batch [18/19]: 100%|██████████| 19/19 [00:31<00:00,  1.64s/it, loss=0.732]


BEST VALIDATION MODEL LOSS: 0.7323

Attempting to save best model weights...
	Save successful!



