In [None]:
# Install AK_SSL

!pip install AK_SSL

## Vision Part

In [None]:
# import libraries

from AK_SSL.vision import Trainer
import torch
import torchvision

In [None]:
# Pretext data for self-supervised training: the "unlabeled" STL-10 split,
# with every image converted to a tensor. With download=True, torchvision
# fetches the data into the given root if it is not already there.
train_unlabeled_dataset = torchvision.datasets.STL10(
    download=True,
    split="unlabeled",
    root="../datasets/" + "stl10",
    transform=torchvision.transforms.ToTensor(),
)

In [None]:
# Backbone: a ResNet-18 with no pretrained weights (weights=None).
# Its classification head `fc` is replaced by an identity map, so the network
# returns the features that used to feed `fc`. Their width is kept as
# `feature_size` for the Trainer.
backbone = torchvision.models.resnet18(weights=None)
feature_size, backbone.fc = backbone.fc.in_features, torch.nn.Identity()

In [None]:
# Wrap the backbone in a Barlow Twins trainer.
# image_size=96 is presumably chosen to match STL-10's image resolution.
# Checkpoints go to ./save_for_report/ every 50 epochs; no previous checkpoint
# is reloaded.
trainer = Trainer(
    backbone=backbone,
    feature_size=feature_size,
    method="barlowtwins",
    image_size=96,
    reload_checkpoint=False,
    checkpoint_interval=50,
    save_dir="./save_for_report/",
)

In [None]:
# Pretext training on the unlabeled split: 500 epochs starting at epoch 1,
# batches of 256, Adam with learning rate 1e-3 and weight decay 1e-6.
trainer.train(
    dataset=train_unlabeled_dataset,
    batch_size=256,
    epochs=500,
    start_epoch=1,
    learning_rate=1e-3,
    weight_decay=1e-6,
    optimizer="Adam",
)

In [None]:
# Labeled STL-10 splits used to evaluate the learned representation.
# Both splits share every setting except `split`, so they are built from one
# expression and unpacked in order: "train" first, then "test".
train_label_dataset, test_dataset = (
    torchvision.datasets.STL10(
        root="../datasets/" + "stl10",
        split=split_name,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
    for split_name in ("train", "test")
)

In [None]:
# Evaluate the trained backbone with AK_SSL's "linear" protocol. Per the
# argument names, it fits on train_dataset and reports on test_dataset.
# NOTE(review): top_k=1 presumably means top-1 accuracy, and
# fine_tuning_data_proportion=1 presumably uses the whole labeled train
# split — confirm against the AK_SSL docs.
trainer.evaluate(
    eval_method="linear",
    train_dataset=train_label_dataset,
    test_dataset=test_dataset,
    fine_tuning_data_proportion=1,
    top_k=1,
    batch_size=256,
    epochs=100,
    optimizer="Adam",
    learning_rate=1e-3,
    weight_decay=1e-6,
)

## Multimodal Part

In [None]:
# Download Flicker Dataset form kaggle

!pip install kaggle
!mkdir ~/.kaggle
!kaggle datasets download -d adityajn105/flickr8k
!unzip flickr8k.zip &> /dev/null

In [None]:
# import libraries

import os
import numpy as np
import pandas as pd

In [None]:
# Read the Flickr8k captions and preview a random sample of images and captions.
# The code below expects captions.txt to have "image" and "caption" columns.
SEED = 42  # fixed seed so the same sample is drawn on every run
rng = np.random.default_rng(SEED)

df = pd.read_csv("captions.txt")
# A bare expression in the middle of a cell is never rendered; display() shows it.
display(df.head())

# Directory containing the Flickr8k images
image_dir = "Images"

# All .jpg files in the directory. They are sorted because os.listdir order is
# arbitrary, and a fixed order keeps the seeded sample reproducible.
image_files = sorted(f for f in os.listdir(image_dir) if f.endswith(".jpg"))

# Randomly select 10 distinct image files
selected_images = rng.choice(image_files, size=10, replace=False)

# Print the names of the selected images
print("Selected images:")
for image_file in selected_images:
    print(image_file)

# Pick one caption at random for each selected image
# (each image may have several rows in captions.txt).
captions_for_selected_images = [
    rng.choice(df.loc[df["image"] == image_file, "caption"].to_numpy())
    for image_file in selected_images
]

print("---------------------------------------")
print("Captions for selected images:")
for caption in captions_for_selected_images:
    print(caption)

In [None]:
# Load the tokenizer from the same checkpoint as the DistilBERT text encoder
# ("distilbert-base-uncased").
# It uses the same uncased WordPiece vocabulary as BERT. Unlike BertTokenizer,
# it does not emit token_type_ids, which DistilBertModel does not accept.
from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [None]:
# Prepare the image/caption dataset for the Trainer.
# Image entries and captions must line up index-for-index. Pair each selected
# image with the caption drawn for it; passing every file in `image_files`
# next to only 10 captions would mismatch both length and content.
from AK_SSL.multimodal.models.utils.clip import CustomClipDataset, get_image_transform

img_transforms = get_image_transform()

# NOTE(review): assumes CustomClipDataset opens each entry as a file path.
# Bare file names would not resolve outside `image_dir`, so join the directory.
selected_image_paths = [os.path.join(image_dir, f) for f in selected_images]
assert len(selected_image_paths) == len(captions_for_selected_images), (
    "every image needs exactly one caption"
)
# NOTE(review): this trains on the small preview sample only. For a real run,
# build both lists from every row of `df` instead.

dataset = CustomClipDataset(
    selected_image_paths, captions_for_selected_images, tokenizer, img_transforms
)

In [None]:
# Encoders for CLIP:
#   text  - pretrained DistilBERT ("distilbert-base-uncased")
#   image - ResNet-18 with ImageNet weights
# `pretrained=True` is deprecated in torchvision. IMAGENET1K_V1 is the weight
# set it selected, and the `weights=` form matches the vision backbone above.
# NOTE(review): unlike the vision backbone, this ResNet keeps its `fc` head.
from transformers import DistilBertModel

txt_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")

img_encoder = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
)

In [None]:
# Build a CLIP trainer around the image and text encoders.
# NOTE: this import rebinds `Trainer`, shadowing the vision Trainer imported
# earlier in the notebook.
from AK_SSL.multimodal import Trainer

trainer = Trainer(
    method="CLIP",
    # --- encoders and their output widths ---
    image_encoder=img_encoder,
    # NOTE(review): 0 presumably means "infer from the encoder"; confirm in AK_SSL.
    image_feature_dim=0,
    text_encoder=txt_encoder,
    # presumably DistilBERT's hidden size
    text_feature_dim=768,
    embed_dim=256,
    # --- loss settings (tau starts at log(1.0) == 0.0) ---
    init_tau=np.log(1.0),
    init_bias=0.0,
    use_siglip=False,
    # --- run / bookkeeping ---
    mixed_precision_training=True,
    verbose=True,
    save_dir="./save_for_report/",
    checkpoint_interval=50,
    reload_checkpoint=False,
)

In [None]:
# Train CLIP on the image/caption dataset: 100 epochs starting at epoch 1,
# batches of 256, Adam with learning rate 1e-3 and weight decay 1e-6.
trainer.train(
    train_dataset=dataset,
    batch_size=256,
    epochs=100,
    start_epoch=1,
    learning_rate=1e-3,
    weight_decay=1e-6,
    optimizer="Adam",
)