In [1]:
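# Shell helpers for this notebook (run them in a terminal, not in a cell):
#   mlflow ui --port 6010 --backend-store-uri file:/share/lazy/will/ConstrastiveLoss/Logs   (browse the logged runs)
#   watch -n 0.5 nvidia-smi                                                                  (monitor GPU usage)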

In [2]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, utils

from VQVAE import VQVAE_Encoder as small_model
from VQVAE import VQVAE as big_model

from train import knowledge_distillation
from utilities import start_mlflow_experiment, Params, save_to_mlflow, count_parameters, load_full_state, select_gpu

from tqdm import tqdm
import mlflow

In [4]:
# --- Run configuration -------------------------------------------------------
# The device is pinned to 'cuda:0'; the select_gpu helper stays disabled.
##### device = select_gpu(1)
# NOTE(review): Params is positional; the fields read below are
# size, batch_size, lr, epoch and device — confirm the argument order in utilities.
args = Params(16, 10, 4e-4, 256, 'cuda:0')

start_mlflow_experiment('VQVAE2 Knowledge distillation', 'lane-finder')

# --- Data --------------------------------------------------------------------
# Resize + centre-crop to a square of side args.size, convert to a tensor and
# normalise every channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

dataset = datasets.ImageFolder('/share/lazy/will/ConstrastiveLoss/Imgs/color_images/train/', transform=transform)
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True)

# --- Teacher -----------------------------------------------------------------
teacher_model = big_model(channel=128).to(args.device)

# load_full_state takes an optimizer argument, so a throwaway one is built for
# the teacher. It gets its own name so that `optimizer` only ever refers to the
# student's optimizer, which is the one actually used for training below.
teacher_optimizer = optim.Adam(teacher_model.parameters(), lr=args.lr)
load_full_state(teacher_model, teacher_optimizer, '/share/lazy/will/ConstrastiveLoss/Logs/0/64a43ca191944cba89536145c4422027/artifacts/run_stats.pyt', freeze_weights=False)

# The only optimizer handed to knowledge_distillation is built from the
# student's parameters, so the teacher acts as a fixed reference. Put it in
# inference mode so that train-mode-only behaviour of its sub-modules is off
# while it produces targets.
# NOTE(review): confirm knowledge_distillation does not already set the modes.
teacher_model.eval()

# First registered buffer of each teacher quantizer.
# NOTE(review): presumably the codebook embeddings — verify against VQVAE.
embed_b = next(teacher_model.quantize_b.buffers())
embed_t = next(teacher_model.quantize_t.buffers())

# --- Student -----------------------------------------------------------------
student_model = small_model(channel=128).to(args.device)

optimizer = optim.Adam(student_model.parameters(), lr=args.lr)

# --- Training ----------------------------------------------------------------
run_name = 'identical encoders, not lookup table'

with mlflow.start_run(run_name=run_name) as run:

    for epoch in range(args.epoch):
        # One pass over the loader; the returned iterable is logged entry by entry.
        results = knowledge_distillation(epoch, loader, teacher_model, student_model, optimizer, args.device)
        for metrics in results:
            save_to_mlflow(metrics, args)


we also froze 0 weights
Of the 79.0 parameter layers to update in the current model, 79.0 were loaded


  result,i = ctx.saved_variables
epoch: 1; avg mse: 0.02443; lr: 0.00040:  93%|█████████▎| 3374/3642 [1:04:36<05:07,  1.15s/it]


KeyboardInterrupt: 

In [None]:
# Materialise every buffer registered on the student's quantize_b module.
buffers_s = list(student_model.quantize_b.buffers())

In [None]:
# Display the first buffer of the student's quantize_b module
# (buffers_s is built in the preceding cell).
buffers_s[0]

In [None]:
# Materialise every buffer registered on the teacher's quantize_b module.
buffers_t = list(teacher_model.quantize_b.buffers())

In [None]:
# Display the first buffer of the teacher's quantize_b module
# (buffers_t is built in the preceding cell), for comparison with the student's.
buffers_t[0]