In [None]:
from transformers import CLIPImageProcessor
from diffusers import AutoencoderTiny
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import os
import pandas as pd
from torchvision.io import decode_image
from torch.utils.data import DataLoader
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim 
from torch.nn import TripletMarginLoss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float32)
vae.to("cuda")

Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.


AutoencoderTiny(
  (encoder): EncoderTiny(
    (layers): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): AutoencoderTinyBlock(
        (conv): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): ReLU()
          (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (skip): Identity()
        (fuse): ReLU()
      )
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (3): AutoencoderTinyBlock(
        (conv): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): ReLU()
          (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  

In [3]:
"""
data_fp = "scene_data/train-scene classification/train/"

img = Image.open(data_fp + "5.jpg")
processed = processor(img, return_tensors = 'pt')
pixel_values = processed['pixel_values'].to(vae.device).to(dtype=vae.dtype)
with torch.no_grad():
    latents = vae.encode(pixel_values).latents


print("Latent shape:", latents.shape)
"""

'\ndata_fp = "scene_data/train-scene classification/train/"\n\nimg = Image.open(data_fp + "5.jpg")\nprocessed = processor(img, return_tensors = \'pt\')\npixel_values = processed[\'pixel_values\'].to(vae.device).to(dtype=vae.dtype)\nwith torch.no_grad():\n    latents = vae.encode(pixel_values).latents\n\n\nprint("Latent shape:", latents.shape)\n'

In [4]:
"""
decoded_image = vae.decode(latents).sample
tensor_to_pil = transforms.ToPILImage()
pil_image = tensor_to_pil(decoded_image[0])

pil_image
"""

'\ndecoded_image = vae.decode(latents).sample\ntensor_to_pil = transforms.ToPILImage()\npil_image = tensor_to_pil(decoded_image[0])\n\npil_image\n'

In [5]:
class ImageTripletDataset(Dataset):
    def __init__(self, annotations_file, img_dir):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        anchor = self.img_labels.iloc[idx]
        positive = self.img_labels[self.img_labels['label'] == anchor['label']].sample(1).iloc[0]
        negative = self.img_labels[self.img_labels['label'] != anchor['label']].sample(1).iloc[0]

        anchor_image = Image.open(os.path.join(self.img_dir, anchor['image_name']))
        positive_image = Image.open(os.path.join(self.img_dir, positive['image_name']))
        negative_image = Image.open(os.path.join(self.img_dir, negative['image_name']))

        a_processed = processor(anchor_image, return_tensors = 'pt')
        p_processed = processor(positive_image, return_tensors = 'pt')
        n_processed = processor(negative_image, return_tensors = 'pt')

        anchor_tensor = a_processed['pixel_values'].to(dtype=vae.dtype)
        positive_tensor = p_processed['pixel_values'].to(dtype=vae.dtype)
        negative_tensor = n_processed['pixel_values'].to(dtype=vae.dtype)

        return anchor_tensor[0], positive_tensor[0], negative_tensor[0]

In [6]:
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = decode_image(img_path)

        # Define the desired output size (e.g., 224x224)
        desired_size = (150, 150)

        # Create a Resize transform
        resize_transform = transforms.Resize(desired_size)

        # Apply the transform to the decoded image
        resized_image_tensor = resize_transform(image)

        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            resized_image_tensor = self.transform(resized_image_tensor)
        if self.target_transform:
            label = self.target_transform(label)
        return resized_image_tensor, label

In [7]:
img_dir = "scene_data/train-scene classification/train/"
train_annotations = "scene_data/train-scene classification/train.csv"
test_annotations = "scene_data/test.csv"

training_data = CustomImageDataset(train_annotations, img_dir)
test_data = CustomImageDataset(test_annotations, img_dir)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

triplet_data = ImageTripletDataset(train_annotations, img_dir)
triplet_dataloader = DataLoader(triplet_data, batch_size=64, shuffle = True)

In [8]:
"""
# Display image and label.
anchors, positives, negatives = next(iter(triplet_dataloader))
print(f"Feature batch shape: {anchors.size()}")
a_img = anchors[0].squeeze().permute(1,2,0)
plt.imshow(a_img, cmap="gray")
plt.show()

p_img = positives[0].squeeze().permute(1,2,0)
plt.imshow(p_img, cmap="gray")
plt.show()

n_img = negatives[0].squeeze().permute(1,2,0)
plt.imshow(n_img, cmap="gray")
plt.show()
"""

'\n# Display image and label.\nanchors, positives, negatives = next(iter(triplet_dataloader))\nprint(f"Feature batch shape: {anchors.size()}")\na_img = anchors[0].squeeze().permute(1,2,0)\nplt.imshow(a_img, cmap="gray")\nplt.show()\n\np_img = positives[0].squeeze().permute(1,2,0)\nplt.imshow(p_img, cmap="gray")\nplt.show()\n\nn_img = negatives[0].squeeze().permute(1,2,0)\nplt.imshow(n_img, cmap="gray")\nplt.show()\n'

In [None]:
criterion = TripletMarginLoss(margin=1.0, p=2) # p=2 for Euclidean distance
optimizer = optim.Adam(vae.parameters(), lr=0.001)

: 

In [None]:
num_epochs = 10
for epoch in range(num_epochs):
    for i, (anchor_data, positive_data, negative_data) in enumerate(triplet_dataloader):
        # Get embeddings from the model

        anchor_embedding = vae.encode(anchor_data.to(vae.device)).latents
        positive_embedding = vae.encode(positive_data.to(vae.device)).latents
        negative_embedding = vae.encode(negative_data.to(vae.device)).latents

        # Calculate Triplet Margin Loss
        loss = criterion(anchor_embedding, positive_embedding, negative_embedding)

        # Backpropagation and Optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(triplet_dataloader)}], Loss: {loss.item():.4f}')

print("Training finished.")

Epoch [1/10], Step [10/267], Loss: 0.7947
Epoch [1/10], Step [20/267], Loss: 0.5907
Epoch [1/10], Step [30/267], Loss: 0.4988
Epoch [1/10], Step [40/267], Loss: 0.4655
Epoch [1/10], Step [50/267], Loss: 0.3719
Epoch [1/10], Step [60/267], Loss: 0.4459
Epoch [1/10], Step [70/267], Loss: 0.4961
Epoch [1/10], Step [80/267], Loss: 0.4577
Epoch [1/10], Step [90/267], Loss: 0.4446
Epoch [1/10], Step [100/267], Loss: 0.5730
Epoch [1/10], Step [110/267], Loss: 0.5775
Epoch [1/10], Step [120/267], Loss: 0.5857
Epoch [1/10], Step [130/267], Loss: 0.5175
Epoch [1/10], Step [140/267], Loss: 0.5570
Epoch [1/10], Step [150/267], Loss: 0.4979
Epoch [1/10], Step [160/267], Loss: 0.3647
Epoch [1/10], Step [170/267], Loss: 0.5392
Epoch [1/10], Step [180/267], Loss: 0.4915
Epoch [1/10], Step [190/267], Loss: 0.5509
Epoch [1/10], Step [200/267], Loss: 0.2494
Epoch [1/10], Step [210/267], Loss: 0.4553
Epoch [1/10], Step [220/267], Loss: 0.4466
Epoch [1/10], Step [230/267], Loss: 0.5243
Epoch [1/10], Step [