## Setup


In [1]:
import torch
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import NeuralCompression.neuralcompression.functional as ncF


from lib.illm import get_teacher_decoder
from lib.student_v1 import StudentEncoderBase_V2
from lib.loss import get_loss_functions
from lib.CLIC_dataset import build_trainloader
from lib.logger import LossLogger
from lib.train import train_epoch
from lib.CLIC_dataset import build_activation_dataloader
from lib.activation_generator import extract_teacher_features


device = "cuda" if torch.cuda.is_available() else "cpu"




In [None]:
# Initialize Models
teacher = get_teacher_decoder(quality=3)
student = StudentEncoderBase_V2()
# student.load_state_dict(torch.load("/workspace/unmounted/models/model_35ep.pth"))


## HyperParams
hyperparams = {
    "alpha_hint1": 0.01,
    "alpha_hint2": 0.035,
    "alpha_hint3": 0.035,
    "alpha_hint4": 0.035,
    "alpha_hint5": 0.035,
    "beta_latent": 0.7,
    "gamma_msssim": 0.01,
    "gamma_perc": 0.001,
    "learning_rate": 0.0085
}

## Optimizer and Scheduler
optimizer = optim.Adam(student.parameters(), lr=hyperparams["learning_rate"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-6)


## Logger and Directory to save images
logger = LossLogger()

## Training

In [3]:
logits_dir = r"C:\Users\IOT_project\Desktop\Quick_Access_CLIC\ILLM_Q3_B32_torch"

extract_teacher_features(
    batch_size=32,
    quality=3,
    output_dir=logits_dir
)

100%|██████████| 157/157 [06:07<00:00,  2.34s/it]


In [5]:
# Example training loop
num_epochs = 200
dataloader = build_activation_dataloader(dir=logits_dir, generate=False)

In [None]:
directory = r"D:\IOT_project\runs\test_run"


for epoch in range(num_epochs):
    
    train_loss = train_epoch(
        student=student,
        teacher_decoder=teacher,
        dataloader=dataloader,
        optimizer=optimizer,
        hyperparams=hyperparams, 
        epoch=epoch,
        save_dir=directory
    )
    
    logger.log(train_loss)
    scheduler.step()

Epoch 0: 100%|██████████| 157/157 [01:28<00:00,  1.78it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.13511834..0.7928622].


Component-Wise Loss
Hint 1:  4.326128501800975
Hint 2:  2.142341914450287
Hint 3:  2.5652066594476155
Hint 4:  3.5316021150084818
Hint 5:  7.325824903834397
Latent Loss:  3.4805395781614217
SSIM Loss:  10.263913601067415
VGG Loss:  2666.8882792284535


Epoch 1: 100%|██████████| 157/157 [01:34<00:00,  1.66it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.14958364..1.0555358].


Component-Wise Loss
Hint 1:  4.074044897298145
Hint 2:  1.9967529914181703
Hint 3:  2.1586000858598453
Hint 4:  2.93766942753154
Hint 5:  5.429559793441919
Latent Loss:  2.090475540631896
SSIM Loss:  7.7651453078932064
VGG Loss:  2387.5574939509106


Epoch 2: 100%|██████████| 157/157 [01:22<00:00,  1.90it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.013217159..0.68524224].


Component-Wise Loss
Hint 1:  3.8710930878949013
Hint 2:  1.895699362845937
Hint 3:  1.8907168875833986
Hint 4:  2.645295735377415
Hint 5:  4.685812087575341
Latent Loss:  1.6606666648843487
SSIM Loss:  6.135618067091437
VGG Loss:  2128.852436430135


Epoch 3: 100%|██████████| 157/157 [01:19<00:00,  1.97it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.034390666..1.0439644].


Component-Wise Loss
Hint 1:  3.7518154018244165
Hint 2:  1.8722316216511332
Hint 3:  1.790192828436566
Hint 4:  2.5765479674005203
Hint 5:  4.375722024850784
Latent Loss:  1.3518628686856313
SSIM Loss:  5.08157299430507
VGG Loss:  1899.339336030802


Epoch 4: 100%|██████████| 157/157 [01:21<00:00,  1.92it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.087634526..1.0902126].


Component-Wise Loss
Hint 1:  3.6405344905367323
Hint 2:  1.842165416593005
Hint 3:  1.646805118223664
Hint 4:  2.416073217513455
Hint 5:  4.0532939202466585
Latent Loss:  1.0704817787097518
SSIM Loss:  4.384900068781178
VGG Loss:  1697.4690221312699


Epoch 5: 100%|██████████| 157/157 [01:20<00:00,  1.94it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.016386934..0.99101424].


Component-Wise Loss
Hint 1:  3.5674692331605655
Hint 2:  1.8228726605321193
Hint 3:  1.5823857167344184
Hint 4:  2.3302847427927005
Hint 5:  3.7700119333662045
Latent Loss:  0.9604956722183592
SSIM Loss:  4.052030359863475
VGG Loss:  1585.021964103553


Epoch 6: 100%|██████████| 157/157 [01:19<00:00,  1.97it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.11448652..1.0893776].


Component-Wise Loss
Hint 1:  3.510013366960416
Hint 2:  1.7945427953437636
Hint 3:  1.5482511387509146
Hint 4:  2.2597914873414737
Hint 5:  3.5483469143035307
Latent Loss:  0.8443885663892053
SSIM Loss:  3.799088927590923
VGG Loss:  1486.3383493605693


Epoch 7: 100%|██████████| 157/157 [01:17<00:00,  2.04it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.050597616..1.1086583].


Component-Wise Loss
Hint 1:  3.4533381594973767
Hint 2:  1.7577035134764993
Hint 3:  1.504260913012134
Hint 4:  2.177101954153389
Hint 5:  3.351340670874164
Latent Loss:  0.7676235991678421
SSIM Loss:  3.6369282910778264
VGG Loss:  1416.449191925632


Epoch 8: 100%|██████████| 157/157 [01:17<00:00,  2.01it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0673063..1.0579479].


Component-Wise Loss
Hint 1:  3.4091963991997347
Hint 2:  1.7241622036809374
Hint 3:  1.4613481399359975
Hint 4:  2.1019121609675655
Hint 5:  3.2150345114386005
Latent Loss:  0.6913899809691557
SSIM Loss:  3.4046224515149546
VGG Loss:  1326.4176849559615


Epoch 9: 100%|██████████| 157/157 [01:21<00:00,  1.93it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.06221377..0.84609705].


Component-Wise Loss
Hint 1:  3.36872857922961
Hint 2:  1.7041416105571066
Hint 3:  1.4313841215364493
Hint 4:  2.068936211478179
Hint 5:  3.1156378468130805
Latent Loss:  0.6574027024826427
SSIM Loss:  3.2388357204996097
VGG Loss:  1251.2394059296626


Epoch 10: 100%|██████████| 157/157 [01:20<00:00,  1.95it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.008961737..1.1710408].


Component-Wise Loss
Hint 1:  3.335603348009146
Hint 2:  1.6751622886035094
Hint 3:  1.3956674810048122
Hint 4:  2.0080825424498054
Hint 5:  3.015876837596772
Latent Loss:  0.5941532494346048
SSIM Loss:  3.091427632957507
VGG Loss:  1190.4779490088201


Epoch 11: 100%|██████████| 157/157 [01:16<00:00,  2.04it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.17673886..1.1457171].


Component-Wise Loss
Hint 1:  3.307863556655349
Hint 2:  1.6545206968951378
Hint 3:  1.3777721757721748
Hint 4:  1.9546050696995607
Hint 5:  2.9329410571201593
Latent Loss:  0.5648899292869932
SSIM Loss:  2.997040201903908
VGG Loss:  1143.757645139269


Epoch 12: 100%|██████████| 157/157 [01:18<00:00,  1.99it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.107500546..1.0863568].


Component-Wise Loss
Hint 1:  3.2811407682242666
Hint 2:  1.6297695991719605
Hint 3:  1.3469049805288862
Hint 4:  1.8924958237037537
Hint 5:  2.863532233390079
Latent Loss:  0.5180561496953296
SSIM Loss:  2.8848404610992238
VGG Loss:  1094.5169677734375


Epoch 13: 100%|██████████| 157/157 [01:20<00:00,  1.95it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.11447127..1.108344].


Component-Wise Loss
Hint 1:  3.2614185331733365
Hint 2:  1.6118292846497457
Hint 3:  1.3311352570345447
Hint 4:  1.839960374270275
Hint 5:  2.7986268199932804
Latent Loss:  0.4989002090730485
SSIM Loss:  2.8138202011205586
VGG Loss:  1057.7437394257563


Epoch 14: 100%|██████████| 157/157 [01:17<00:00,  2.03it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.02685254..0.99647427].


Component-Wise Loss
Hint 1:  3.2414861089864355
Hint 2:  1.5925150227015186
Hint 3:  1.3150939478236399
Hint 4:  1.7939240404754688
Hint 5:  2.745613475893713
Latent Loss:  0.464833800959739
SSIM Loss:  2.7388657551662177
VGG Loss:  1023.7457570847432


Epoch 15: 100%|██████████| 157/157 [01:34<00:00,  1.66it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.06471952..1.000499].


Component-Wise Loss
Hint 1:  3.2237623223833216
Hint 2:  1.570347463249401
Hint 3:  1.2913140548262627
Hint 4:  1.7580798406889484
Hint 5:  2.6879417383746738
Latent Loss:  0.4426118526014553
SSIM Loss:  2.67679982883915
VGG Loss:  994.8800017727408


Epoch 16: 100%|██████████| 157/157 [01:38<00:00,  1.59it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.07616501..0.79625404].


Component-Wise Loss
Hint 1:  3.208167413997043
Hint 2:  1.5539326709546861
Hint 3:  1.2714157039952125
Hint 4:  1.721702942612824
Hint 5:  2.6403202460070325
Latent Loss:  0.4209517399976208
SSIM Loss:  2.6164287700774564
VGG Loss:  968.8707461994925


Epoch 17:   0%|          | 0/157 [00:00<?, ?it/s]

In [None]:
logger.save(directory)
logger.plot()
torch.save(student.state_dict(), f"{directory}/model_{num_epochs}ep.pth")