## Setup


In [None]:
import torch
from student_v1 import StudentEncoderBase_V2
from lib.loss import get_loss_functions
import torch.optim as optim
from lib.data_handlers.CLIC_dataset import build_trainloader
from tqdm import tqdm
import matplotlib.pyplot as plt


import NeuralCompression.neuralcompression.functional as ncF
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"



In [2]:
# Fetch the pretrained "msillm_quality_3" model from the NeuralCompression hub
# repo and switch it to eval mode straight away. force_reload=True makes
# torch.hub download the repo again instead of reusing the cached copy.
model = torch.hub.load(
    "facebookresearch/NeuralCompression",
    "msillm_quality_3",
    force_reload=True,
).eval()
model.update()
model.update_tensor_devices("compress")

# The pretrained model is never trained here: turn off gradient tracking on
# every one of its parameters.
for frozen_param in model.parameters():
    frozen_param.requires_grad = False

# Teacher = the pretrained encoder; student = a freshly initialised encoder.
teacher = model.encoder
student = StudentEncoderBase_V2()
# Optional: resume the student from an earlier checkpoint.
# student.load_state_dict(torch.load("/workspace/unmounted/models/model_35ep.pth"))

Downloading: "https://github.com/facebookresearch/NeuralCompression/zipball/main" to /root/.cache/torch/hub/main.zip


In [3]:
# get_loss_functions() returns three objects, unpacked in this order:
# the MS-SSIM loss, the VGG perceptual loss and the distillation loss.
msssim_loss, vgg_perceptual, distillation_loss = get_loss_functions()
# Only the perceptual loss is moved to `device`
# (presumably because it wraps network weights - confirm in lib.loss).
vgg_perceptual = vgg_perceptual.to(device)

In [4]:
## Grid Search Params
# One weight per hint stage. The original cell assigned `alpha_hint4` twice and
# never defined a weight for the fifth hint; the second assignment is now
# `alpha_hint5`.
# NOTE(review): presumably these alpha/beta values are meant to scale the hint
# and latent terms of the training loss, but the loss in train_epoch currently
# uses a fixed 1.0 for those terms - confirm which is intended.
alpha_hint1 = 0.01
alpha_hint2 = 0.035
alpha_hint3 = 0.035
alpha_hint4 = 0.035
alpha_hint5 = 0.035
beta_latent = 0.7
# Weights applied to the MS-SSIM and perceptual terms of the training loss.
gamma_msssim = 0.01
gamma_perc = 0.001
# Initial learning rate handed to the optimizer.
learning_rate = 0.0085

## Training

In [5]:
# Adam over every student parameter, starting at the grid-search learning rate.
optimizer = torch.optim.Adam(student.parameters(), lr=learning_rate)

# Cosine schedule that reaches eta_min after T_max scheduler steps.
# NOTE(review): T_max is 30 while the training loop in this notebook runs for
# 200 epochs; CosineAnnealingLR does not stop at T_max, so the learning rate
# would climb back up afterwards - confirm this is intended.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=30,
    eta_min=1e-6,
)

In [6]:
class LossLogger:
    """Collect named loss values over time, with plotting and save/load helpers.

    ``losses`` maps every entry of ``loss_names`` to a list that grows by one
    value per ``log`` call.
    """

    def __init__(self):
        # Order matters: ``log`` pairs the positions of the incoming tuple
        # with these names.
        self.loss_names = [
            'avg_hint1_loss', 'avg_hint2_loss', 'avg_hint3_loss', 'avg_hint4_loss', 'avg_hint5_loss',
            'avg_latent_loss', 'avg_ssim_loss', 'avg_perc_loss', 'epoch_loss'
        ]
        self.losses = {name: [] for name in self.loss_names}

    def log(self, loss_tuple):
        """Append one value to every tracked loss series.

        Args:
            loss_tuple: Sequence with exactly ``len(self.loss_names)`` values,
                ordered like ``self.loss_names``.

        Raises:
            AssertionError: If the number of values does not match.
        """
        assert len(loss_tuple) == len(self.loss_names), "Mismatch in number of losses"
        for name, value in zip(self.loss_names, loss_tuple):
            self.losses[name].append(value)

    def plot(self):
        """Draw every stored loss series on one shared figure and show it."""
        plt.figure(figsize=(12, 8))
        for name, values in self.losses.items():
            plt.plot(values, label=name)
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Losses over Epochs")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    def save(self, path):
        """Serialise the ``losses`` dictionary to ``path`` with ``torch.save``."""
        torch.save(self.losses, path)

    def load(self, path):
        """Replace ``losses`` with the dictionary stored at ``path``.

        The loaded object is validated before it is assigned, so a file that
        lacks some of the expected keys leaves the current state untouched.

        NOTE: ``torch.load`` is pickle-based; only load files you trust.

        Raises:
            AssertionError: If any name in ``loss_names`` is missing from the
                loaded object.
        """
        loaded = torch.load(path)
        assert all(name in loaded for name in self.loss_names), "Loaded file missing some loss keys"
        self.losses = loaded



In [7]:
# Fresh logger that receives the loss tuple returned by each training epoch.
logger = LossLogger()

In [None]:
# Everything that takes part in the training forward pass has to live on
# `device`: the student being trained, the pretrained decoder that produces
# reconstructions from its output, and the perceptual loss.
student.to(device)
model.decoder.to(device)
vgg_perceptual = vgg_perceptual.to(device)

def train_epoch(dataloader, epoch=None):
    """Train the student encoder for one pass over ``dataloader``.

    Every batch is a mapping holding an ``"images"`` tensor plus target
    tensors under ``"layer1"`` .. ``"layer6"``. The image batch is passed
    through ``ncF.pad_image_to_factor(x, model._factor)`` and then through
    ``student.block1`` .. ``student.block6``; after each block the output is
    compared with the matching target using ``distillation_loss``. The output
    of the last block is decoded with ``model.decoder`` and the reconstruction
    is scored against the padded input with ``msssim_loss`` and
    ``vgg_perceptual``.

    The function relies on notebook globals: ``student``, ``model``,
    ``optimizer``, ``device``, the three loss callables and the
    ``gamma_msssim`` / ``gamma_perc`` weights.

    Args:
        dataloader: Iterable of batches as described above; must support
            ``len()``.
        epoch: Optional epoch index, used only for the progress-bar label and
            the title of the reconstruction plot.

    Returns:
        Tuple of nine per-sample averages, in this order: hint 1-5 losses,
        latent loss, SSIM loss, perceptual loss, total loss.

    Raises:
        ValueError: If the dataloader produced no samples.
    """
    student.train()
    student.to(device)

    # Student stages and the batch key holding the target for each stage.
    blocks = (
        student.block1, student.block2, student.block3,
        student.block4, student.block5, student.block6,
    )
    target_keys = ("layer1", "layer2", "layer3", "layer4", "layer5", "layer6")

    # Running sums in return order: 6 stage losses, ssim, perc, total.
    sums = [0.0] * 9
    # Averages are taken over the samples actually processed, so they stay
    # correct when the loader drops the last batch or sub-samples the dataset.
    samples_seen = 0
    x = x_recon = None

    loop = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch if epoch is not None else ''}")

    for batch in loop:
        x = batch["images"].to(device)
        # Padding correction
        x, (_, _) = ncF.pad_image_to_factor(x, model._factor)

        optimizer.zero_grad()

        # Feed the image through the student stage by stage, scoring each
        # intermediate output against its target.
        y = x
        stage_losses = []
        for block, key in zip(blocks, target_keys):
            y = block(y)
            stage_losses.append(distillation_loss(y, batch[key].to(device)))

        x_recon = model.decoder(y)

        perc_loss = vgg_perceptual(x, x_recon)
        ssim_loss = msssim_loss(x, x_recon)

        # Cumulative loss: all six distillation terms carry weight 1.0.
        # NOTE(review): the alpha_hint*/beta_latent values defined in the
        # grid-search cell are not applied here - confirm that is intended.
        loss = stage_losses[0]
        for term in stage_losses[1:]:
            loss = loss + term
        loss = loss + gamma_msssim * ssim_loss + gamma_perc * perc_loss

        # Backprop and optimize
        loss.backward()
        optimizer.step()

        # Accumulate batch-size-weighted values so the final division yields
        # per-sample averages.
        batch_size = x.size(0)
        samples_seen += batch_size
        batch_values = (*stage_losses, ssim_loss, perc_loss, loss)
        for idx, value in enumerate(batch_values):
            sums[idx] += value.item() * batch_size

    if samples_seen == 0:
        raise ValueError("train_epoch received a dataloader that produced no samples")

    averages = tuple(total / samples_seen for total in sums)
    (avg_hint1_loss, avg_hint2_loss, avg_hint3_loss, avg_hint4_loss,
     avg_hint5_loss, avg_latent_loss, avg_ssim_loss, avg_perc_loss,
     epoch_loss) = averages

    print("Component-Wise Loss")
    print("Hint 1: ", avg_hint1_loss)
    print("Hint 2: ", avg_hint2_loss)
    print("Hint 3: ", avg_hint3_loss)
    print("Hint 4: ", avg_hint4_loss)
    print("Hint 5: ", avg_hint5_loss)
    print("Latent Loss: ", avg_latent_loss)
    print("SSIM Loss: ", avg_ssim_loss)
    print("VGG Loss: ", avg_perc_loss)

    # Show the first image of the last batch next to its reconstruction.
    # permute(1, 2, 0) reorders the axes of the 3-D tensor for imshow.
    x_vis = x[0].detach().cpu().permute(1, 2, 0)
    x_recon_vis = x_recon[0].detach().cpu().permute(1, 2, 0)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(x_vis)
    axs[0].set_title("Original Image")
    axs[0].axis('off')

    axs[1].imshow(x_recon_vis)
    axs[1].set_title("Reconstructed Image")
    axs[1].axis('off')

    plt.suptitle(f"Reconstruction at Epoch {epoch}")
    plt.show()

    return (avg_hint1_loss, avg_hint2_loss, avg_hint3_loss, avg_hint4_loss,
            avg_hint5_loss, avg_latent_loss, avg_ssim_loss, avg_perc_loss,
            epoch_loss)


## Training

In [None]:
from lib.data_handlers.CLIC_dataset import build_activation_dataloader
# Loader consumed by train_epoch, which reads the keys "images" and
# "layer1".."layer6" from every batch. The directory presumably holds
# precomputed teacher activations - confirm against build_activation_dataloader.
loader = build_activation_dataloader(dir="/workspace/unmounted/CLIC_activations/ILLM_Q3_torch")

In [10]:
# Example training loop: one train_epoch call per epoch, log its loss tuple,
# then advance the learning-rate schedule.
num_epochs = 200

for epoch in range(num_epochs):
    train_loss = train_epoch(loader, epoch=epoch)
    logger.log(train_loss)
    scheduler.step()
    # train_epoch returns the total epoch loss as the LAST element of its
    # tuple; index 0 is only the hint-1 average.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss[-1]:.4f}")

Epoch 0:  28%|██▊       | 44/157 [00:46<01:59,  1.05s/it]


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3548, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_1915/2609219322.py", line 6, in <module>
    train_loss = train_epoch(loader, epoch=epoch)
  File "/tmp/ipykernel_1915/3797491259.py", line 22, in train_epoch
    for i, batch in loop:
  File "/usr/local/lib/python3.10/dist-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/us

In [41]:
# Persist the trained student weights (state dict only) ...
torch.save(student.state_dict(), "/workspace/unmounted/models/model800k_200ep_grid.pth")
# ... and the recorded loss history next to them.
logger.save("/workspace/unmounted/runs/model800k_200ep_grid_loss.pt")

## Quantization Aware Training

In [None]:
# The quantization helpers used below are not imported by the setup cell, so
# they are imported here to keep this cell runnable on its own.
from torch.ao.quantization import (
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    QConfig,
    default_fake_quant,
    default_weight_fake_quant,
)

# Custom QAT configuration (adjust as needed):
#   activations - unsigned 8-bit fake quantization, range [0, 255]
#   weights     - signed 8-bit fake quantization, range [-128, 127]
qat_config = QConfig(
    activation=default_fake_quant.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=0,
        quant_max=255,
        dtype=torch.quint8,
    ),
    weight=default_weight_fake_quant.with_args(
        observer=MinMaxObserver,
        quant_min=-128,
        quant_max=127,
        dtype=torch.qint8,
    ),
)

# Apply configuration to student model
student.qconfig = qat_config

In [None]:
# prepare_qat is not imported by the setup cell; import it where it is used.
from torch.ao.quantization import prepare_qat

# prepare_qat expects the model to be in training mode, so set it explicitly
# instead of relying on whatever mode the last cell left it in.
student.train()

# Prepare model for QAT (inserts fake quantization modules)
student_prepared = prepare_qat(student, inplace=False).to(device)
student = student_prepared
# NOTE(review): inplace=False returns a copy, so the `optimizer` created
# earlier presumably still references the parameters of the previous `student`
# object - rebuild the optimizer/scheduler before any QAT fine-tuning.

# If using FX Graph Mode (recommended for complex models):
# qconfig_mapping = get_default_qat_qconfig_mapping()
# student_prepared = prepare_fx(student, qconfig_mapping, example_inputs=torch.randn(1,3,224,224).to(device))

In [None]:
# convert is not imported by the setup cell; import it where it is used.
from torch.ao.quantization import convert

# NOTE(review): no training step runs between the prepare cell and this one in
# the notebook, so the fake-quant observers presumably have not seen any data
# yet - run some QAT epochs before converting.
# NOTE(review): the prepared model was moved to `device`; confirm that
# conversion and quantized inference are supported there (quantized kernels
# are typically CPU backends).

# Set to evaluation mode and convert
student_prepared.eval()
student_quantized = convert(student_prepared, inplace=False)

# For FX Graph Mode:
# student_quantized = convert_fx(student_prepared)

# Verify quantization
print(student_quantized)