## Create a Dataset of Teacher Latents

In [1]:
import torch
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
class TeacherLatentSaver:
    """Run a teacher network over a dataloader and save each batch's output to disk.

    Every batch produces one file, ``<save_dir>/<batch index zero-padded to 5 digits>.pt``,
    written with ``torch.save``. The file name is the batch's position in the
    iteration order of ``dataloader``; it does not identify the images in the batch.

    Args:
        teacher_model: Module to run; it is switched to eval mode and moved to ``device``.
        dataloader: Iterable of batches. A batch is either an input tensor or a
            tuple/list whose first element is the input tensor (e.g. ``(inputs, target)``).
        save_dir: Output directory; created if it does not already exist.
        device: Device on which the teacher is evaluated.
    """

    def __init__(self, teacher_model, dataloader, save_dir, device):
        self.teacher = teacher_model.eval().to(device)
        self.dataloader = dataloader
        self.save_dir = save_dir
        self.device = device

        os.makedirs(save_dir, exist_ok=True)

    def save_latents(self):
        """Feed every batch through the teacher and save the output as a CPU tensor.

        Gradient tracking is disabled for the whole pass.
        """
        with torch.no_grad():
            for idx, batch in enumerate(self.dataloader):
                # Loaders built on datasets that also return labels yield
                # (inputs, target, ...) batches; only the inputs go to the teacher.
                x = batch[0] if isinstance(batch, (list, tuple)) else batch
                x = x.to(self.device)
                latent = self.teacher(x).cpu()
                torch.save(latent, os.path.join(self.save_dir, f"{idx:05d}.pt"))

In [3]:
# Dataset root directory. The default is the original machine-specific location;
# set the CLIC_DATA_ROOT environment variable to run the notebook elsewhere
# without editing this cell.
path = os.environ.get(
    "CLIC_DATA_ROOT",
    "/home/iot/Desktop/IoT/datasets/CLIC/archive/val2017",
)

# 2. Create custom Dataset class
class CLICDataset(Dataset):
    """Image-only dataset over the ``.png`` / ``.jpg`` / ``.jpeg`` files in ``<root_dir>/<split>``.

    Args:
        root_dir: Directory that contains the split sub-directory.
        transform: Optional callable applied to each loaded PIL image.
        split: Name of the sub-directory to read (default ``'train'``).
    """

    def __init__(self, root_dir, transform=None, split='train'):
        self.root_dir = os.path.join(root_dir, split)
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.image_files = sorted(
            f for f in os.listdir(self.root_dir)
            if f.endswith(('.png', '.jpg', '.jpeg'))
        )

    def __len__(self):
        """Return the number of image files found in the split directory."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if a transform was given."""
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        # Release the file handle as soon as the pixels are decoded; convert()
        # returns a new, fully loaded image that no longer needs the open file.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image

# 3. Baseline transform: only converts the loaded image to a tensor.
# TODO: look up the transformations used in the original paper.
# NOTE(review): `transform` is rebound to default_train_transform(224) further
# down in this cell before the dataset is built, so this pipeline is not the
# one the dataset ends up using.
transform = transforms.Compose([transforms.ToTensor()])

from torchvision.transforms import (
    CenterCrop,
    Compose,
    RandomChoice,
    RandomCrop,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)


def default_train_transform(image_size: int) -> Compose:
    """Build the training-time augmentation pipeline.

    The pipeline applies, in order:
      1. one of two crops chosen by ``RandomChoice`` -- a ``RandomCrop``
         (reflect-padded when the image is smaller than the crop) or a
         ``RandomResizedCrop`` -- both producing ``image_size`` outputs;
      2. a random horizontal flip;
      3. conversion to a tensor.

    Args:
        image_size: Target crop size passed to both crop transforms.

    Returns:
        The composed transform.
    """
    crop_options = [
        RandomCrop(size=image_size, pad_if_needed=True, padding_mode="reflect"),
        RandomResizedCrop(size=image_size),
    ]
    steps = [RandomChoice(crop_options), RandomHorizontalFlip(), ToTensor()]
    return Compose(steps)

# Augmentation pipeline actually used by the dataset below (224-pixel crops).
transform = default_train_transform(224)

# 4. Create the dataset and its dataloader
batch_size = 32

# NOTE(review): batches are shuffled and randomly cropped/flipped, so anything
# derived from a batch cannot be traced back to specific source images
# afterwards -- confirm this is intended for the latent-extraction pass.
train_dataset = CLICDataset(path, transform=transform, split='train')
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Smoke-test the dataloader on a single batch
batch = next(iter(train_loader))
print(f"Batch shape: {batch.shape}")  # Expected: [batch_size, 3, 224, 224]
print(f"No of Batches = {len(train_loader)}")

Batch shape: torch.Size([32, 3, 224, 224])
No of Batches = 157


In [4]:
# Load the pretrained "msillm_quality_vlo1" model from the NeuralCompression
# hub repository and switch it to eval mode.
model = torch.hub.load(
    "facebookresearch/NeuralCompression",
    "msillm_quality_vlo1",
).eval()

# NOTE(review): presumably these two calls prepare the model's entropy-coding
# state for compression -- confirm whether they are needed when only the
# encoder is used.
model.update()
model.update_tensor_devices("compress")

# Only the encoder serves as the teacher; it alone is moved to `device`
# (the full model is left where it was loaded).
teacher = model.encoder
teacher.to(device)

Using cache found in /home/iot/.cache/torch/hub/facebookresearch_NeuralCompression_main


HiFiCEncoder(
  (blocks): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 60, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
      (1): ChannelNorm2D()
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(60, 120, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ChannelNorm2D()
      (2): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(120, 240, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ChannelNorm2D()
      (2): ReLU()
    )
    (3): Sequential(
      (0): Conv2d(240, 480, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ChannelNorm2D()
      (2): ReLU()
    )
    (4): Sequential(
      (0): Conv2d(480, 960, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ChannelNorm2D()
      (2): ReLU()
    )
    (5): Conv2d(960, 220, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In [None]:
# Put the teacher in eval mode before extracting latents.
teacher.eval()

# Arguments, in order: teacher model, dataloader, output directory for the
# .pt files, device.
saver = TeacherLatentSaver(teacher, train_loader, "ILLM_VLO1_Train_Latents", device)

# Writes one .pt file per batch of `train_loader` into the output directory.
saver.save_latents()