In [1]:
from PIL import Image
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPModel
import os

import torch
from tqdm import tqdm
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
# Hub identifier shared by the model weights and the matching processor.
checkpoint_name = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"

# Load the weights in half precision, then move the model to the chosen device.
model = CLIPModel.from_pretrained(checkpoint_name, torch_dtype=torch.float16)
model = model.to(device)

# NOTE(review): torch_dtype is presumably irrelevant for a processor — confirm
# against the installed transformers version before removing it.
processor = CLIPProcessor.from_pretrained(checkpoint_name, torch_dtype=torch.float16)

# Short handles to the four sub-modules used by the rest of the notebook.
model_vision, visual_projection = model.vision_model, model.visual_projection
model_text, text_projection = model.text_model, model.text_projection

In [4]:
# Disable gradients for the vision tower, the text tower and the text
# projection; visual_projection is not touched in this cell.
for frozen_module in (model_vision, model_text, text_projection):
    frozen_module.requires_grad_(False)

# Switch the same three modules to eval mode. The final call is left as the
# cell's last expression so its repr is still displayed.
for frozen_module in (model_text, model_vision):
    frozen_module.eval()
text_projection.eval()

Linear(in_features=1024, out_features=1024, bias=False)

In [5]:
class dataset:
    """In-memory (image, text-embedding) pairs built from a class-per-folder tree.

    ``path`` must contain one sub-directory per class. Every file inside a
    class directory is opened as an image and preprocessed with the global
    ``processor``. Each class name is turned into the prompt
    ``"a photo of <class>"`` and embedded once with the global ``model_text``
    and ``text_projection``.

    Attributes:
        animal_type: sorted list of class (sub-directory) names.
        images: tensor holding the preprocessed pixel values of every image.
        titles: one projected text embedding per class, indexed like
            ``animal_type``.
        labels: for every row of ``images``, the index of its class in
            ``animal_type`` / ``titles``.

    Raises:
        ValueError: if no image is found under ``path``.
    """

    def __init__(self, path):
        print(f"Prepare data for {path}")
        # Only sub-directories are classes; a stray file (e.g. a README) would
        # otherwise make os.listdir() fail further down.
        self.animal_type = sorted(
            entry
            for entry in os.listdir(path)
            if os.path.isdir(os.path.join(path, entry))
        )
        animal_type_prefix = [f"a photo of {animal}" for animal in self.animal_type]

        pixel_batches = []
        self.labels = []
        for class_index, animal in enumerate(tqdm(self.animal_type)):
            class_dir = os.path.join(path, animal)
            images = []
            for file in sorted(os.listdir(class_dir)):
                # Image.open is lazy and keeps the file open; copy() forces the
                # pixel data to load so the handle can be closed right away.
                with Image.open(os.path.join(class_dir, file)) as handle:
                    images.append(handle.copy())
            if not images:
                # An empty class folder contributes no samples.
                continue
            pixel_values = processor(images=images, return_tensors="pt").pixel_values
            pixel_batches.append(pixel_values)
            # Record the class of every image explicitly so classes may have
            # different numbers of images.
            self.labels.extend([class_index] * len(pixel_values))

        if not pixel_batches:
            raise ValueError(f"No images found under {path}")

        # The prompts are the same for every class folder: tokenize and embed
        # them once, without tracking gradients.
        text_inputs = processor(
            text=animal_type_prefix,
            return_tensors="pt",
            padding=True,
        ).to(device)
        with torch.no_grad():
            self.titles = text_projection(
                model_text(input_ids=text_inputs.input_ids).pooler_output
            )
        self.images = torch.cat(pixel_batches)

    def __len__(self):
        """Return the total number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, title)`` for sample ``idx``.

        ``image`` is the already-preprocessed pixel tensor and ``title`` is the
        text embedding of the class that image belongs to.
        """
        image = self.images[idx]
        title = self.titles[self.labels[idx]]
        return (
            image,
            title,
        )

In [6]:
# Build the train split first, then the test split, from their folders.
train_dataset, test_dataset = (
    dataset(path=split_dir)
    for split_dir in ("topic2_release/train", "topic2_release/test")
)

Prepare data for topic2_release/train


100%|██████████| 10/10 [00:18<00:00,  1.89s/it]


Prepare data for topic2_release/test


100%|██████████| 10/10 [00:02<00:00,  4.06it/s]


In [7]:
# Adam is handed every parameter of the full CLIP model.
adam_options = {"lr": 1e-3, "betas": (0.9, 0.98), "eps": 1e-4}
optimizer = torch.optim.Adam(model.parameters(), **adam_options)

# Exponential learning-rate schedule with decay factor 0.5.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

# Mean-squared error between image and text embeddings.
MSELoss = torch.nn.MSELoss()

In [8]:
# Run settings.
num_epochs, batch_size = 10, 100
start = 0  # index of the first epoch to run
# Per-step loss histories, appended to by the training loop.
loss_training, loss_testing = [], []

# Both loaders share the same batch size and are shuffled.
loader_options = dict(batch_size=batch_size, shuffle=True)
train_dataloader = DataLoader(train_dataset, **loader_options)
test_dataloader = DataLoader(test_dataset, **loader_options)

In [9]:
# Directory that receives one sub-folder of checkpoint files per epoch.
model_dir_name = 'output_model'
# exist_ok=True replaces the check-then-create pattern: no race between the
# check and the mkdir, and the cell can be re-run safely.
os.makedirs(model_dir_name, exist_ok=True)

In [10]:
def load_loss_history(checkpoint_dir):
    """Return the ``(training, testing)`` loss lists saved in ``checkpoint_dir``.

    Reads ``<checkpoint_dir>/loss_history.pkl``, which is expected to hold a
    dict with the keys ``"training"`` and ``"testing"``.
    """
    # NOTE: pickle.load can execute arbitrary code; only load files written by
    # this notebook.
    with open(f"{checkpoint_dir}/loss_history.pkl", "rb") as handle:
        saved_losses = pickle.load(handle)
    return saved_losses["training"], saved_losses["testing"]


is_resume = False

if is_resume:
    resume_epoch = 0
    # Checkpoints are written at the end of an epoch, so continue with the next.
    start = resume_epoch + 1
    checkpoint_dir = f"{model_dir_name}/epoch_{resume_epoch}"

    # Copy the saved weights INTO the existing model instead of rebinding
    # `model`: model_vision / visual_projection / model_text / text_projection
    # and the optimizer all hold references to the current model's parameters
    # and would otherwise keep using the stale, un-restored weights.
    restored = CLIPModel.from_pretrained(checkpoint_dir, torch_dtype=torch.float16)
    model.load_state_dict(restored.state_dict())
    del restored

    optimizer.load_state_dict(torch.load(f"{checkpoint_dir}/optimizer.bin"))
    scheduler.load_state_dict(torch.load(f"{checkpoint_dir}/scheduler.bin"))
    loss_training, loss_testing = load_loss_history(checkpoint_dir)

In [11]:
# Materialize one pass over the test loader so the training loop can index
# test batches by position.
test_batches = list(test_dataloader)

In [None]:
# Train so that projected image embeddings match the precomputed text
# embedding of each image's class (MSE loss). A test batch is evaluated every
# 10 training steps, and a checkpoint is written at the end of every epoch.
for epoch in range(start, num_epochs):
    pbar = tqdm(train_dataloader, total=len(train_dataloader))
    for i, batch in enumerate(pbar):
        visual_projection.train()
        optimizer.zero_grad()

        images, texts = batch
        images = images.to(device)
        texts = texts.to(device)

        # Forward pass: image -> vision tower -> visual projection.
        x = visual_projection(model_vision(pixel_values=images).pooler_output)
        # `texts` already holds the projected text embeddings (the targets).
        loss = MSELoss(x, texts)

        # Backward pass
        loss.backward()
        optimizer.step()
        loss_training.append(loss.item())

        if i % 10 == 0:
            visual_projection.eval()
            # Wrap around so a long training epoch cannot index past the last
            # cached test batch.
            test_images, test_texts = test_batches[(i // 10) % len(test_batches)]
            test_images = test_images.to(device)
            test_texts = test_texts.to(device)

            # Evaluation only: no autograd graph is needed. Separate variable
            # names make sure the test images (not the training ones) are
            # evaluated against the test targets.
            with torch.no_grad():
                test_x = visual_projection(
                    model_vision(pixel_values=test_images).pooler_output
                )
                test_loss = MSELoss(test_x, test_texts)
            loss_testing.append(test_loss.item())

        pbar.set_description(f"Epoch: {epoch}/{num_epochs}, Training loss: {loss_training[-1]:.5f}, Testing loss: {loss_testing[-1]:.5f}")
    scheduler.step()

    # Everything for this epoch goes under <model_dir_name>/epoch_<epoch>/.
    checkpoint_dir = f"{model_dir_name}/epoch_{epoch}"
    model.save_pretrained(f"{checkpoint_dir}/")
    torch.save(optimizer.state_dict(), f"{checkpoint_dir}/optimizer.bin")
    torch.save(scheduler.state_dict(), f"{checkpoint_dir}/scheduler.bin")

    # Use model_dir_name here as well, so the loss history always lands next
    # to the other checkpoint files.
    with open(f"{checkpoint_dir}/loss_history.pkl", "wb") as handle:
        save_loss = {"training": loss_training, "testing": loss_testing}
        pickle.dump(save_loss, handle, protocol=pickle.HIGHEST_PROTOCOL)