In [89]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, AdamW, get_scheduler
from tqdm import tqdm

# The tokenizers library warns (and disables its own parallelism) when the
# process forks after the tokenizer has been used; its message asks for
# TOKENIZERS_PARALLELISM to be set explicitly. `setdefault` keeps any value the
# user already exported.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Single source of truth for the pretrained checkpoint used by model and processor.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [90]:
# train dataset
# Root folder holding the split directories. Override with FOOD101_DATA_ROOT to
# run on another machine; the default is the path this cell originally used.
train_data_root = os.environ.get("FOOD101_DATA_ROOT", "/root/20242R0136COSE47402/FinalProject/data")
data_dir = os.path.join(train_data_root, "train")

# One sub-directory per class. Sorted so the class order is deterministic rather
# than depending on os.listdir's arbitrary ordering.
class_candidate = sorted(
    folder for folder in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, folder))
)
# Prompt per class, index-aligned with class_candidate
# (e.g. "apple_pie" -> "a photo of apple pie").
text_inputs = [f"a photo of {folder.replace('_', ' ')}" for folder in class_candidate]

In [91]:
image_paths = []
image_labels = []

for class_name in class_candidate:
    class_folder = os.path.join(data_dir, class_name)
    # Caption is the same for every image of this class: build it once instead
    # of rewriting the loop variable on each inner iteration.
    caption = f"a photo of {class_name.replace('_', ' ')}"
    # Sorted so the dataset order is reproducible across machines.
    for img_name in sorted(os.listdir(class_folder)):
        # Skip hidden files (e.g. .DS_Store) that are not images.
        if img_name.startswith('.'):
            continue
        image_paths.append(os.path.join(class_folder, img_name))
        image_labels.append(caption)

print(f"train dataset size : {len(image_paths)}")

train dataset size : 75750


In [92]:
def collate_fn(batch):
    """Merge a list of ``(processor_inputs, caption)`` samples into one batch.

    ``input_ids`` and ``attention_mask`` are 1-D per sample and may differ in
    length, so they are right-padded with 0 to the longest in the batch.
    ``pixel_values`` are stacked along a new leading batch dimension (so they
    must all share one shape).

    Returns:
        A ``(dict, tuple)`` pair: the dict maps ``'input_ids'``,
        ``'attention_mask'`` and ``'pixel_values'`` to batched tensors, and the
        tuple holds the captions in batch order.
    """
    encodings, captions = zip(*batch)

    def _pad(key):
        # Right-pad the variable-length 1-D tensors stored under `key`.
        return pad_sequence([enc[key] for enc in encodings], batch_first=True, padding_value=0)

    merged = {
        'input_ids': _pad('input_ids'),
        'attention_mask': _pad('attention_mask'),
        'pixel_values': torch.stack([enc['pixel_values'] for enc in encodings]),
    }
    return merged, captions


In [93]:
def clip_loss(logits_per_image, labels=None):
    """Symmetric contrastive (InfoNCE) loss over a square image-text logit matrix.

    Args:
        logits_per_image: (batch, batch) tensor; row i scores image i against
            every caption in the batch.
        labels: optional per-sample labels (a sequence of hashables such as
            caption strings, or a 1-D tensor), length ``batch``. When omitted,
            only the diagonal pair (i, i) is a positive, which is standard CLIP.
            When given, every pair whose labels are equal counts as a positive,
            and the target probability is spread evenly over them. With
            class-name captions, two images of the same class in one batch are
            otherwise pushed apart as if they were negatives.

    Returns:
        Scalar tensor: mean of the image->text and text->image cross-entropies.

    Raises:
        ValueError: if ``labels`` is given with a length other than ``batch``.
    """
    batch_size = logits_per_image.size(0)
    device = logits_per_image.device

    if labels is None:
        # Sample i's only positive is its own caption.
        targets = torch.arange(batch_size, device=device)
        return (F.cross_entropy(logits_per_image, targets) + F.cross_entropy(logits_per_image.T, targets)) / 2

    if len(labels) != batch_size:
        raise ValueError(f"expected {batch_size} labels, got {len(labels)}")

    if isinstance(labels, torch.Tensor):
        label_ids = labels.to(device)
    else:
        # Map each distinct label to a small int so equality is a tensor op.
        id_of = {}
        label_ids = torch.tensor([id_of.setdefault(lab, len(id_of)) for lab in labels], device=device)

    same = (label_ids.unsqueeze(0) == label_ids.unsqueeze(1)).to(logits_per_image.dtype)
    # Row-normalise to a probability distribution over positives. `same` is
    # symmetric, so the same targets apply to the transposed (text->image) logits.
    soft_targets = same / same.sum(dim=1, keepdim=True)
    return (F.cross_entropy(logits_per_image, soft_targets) + F.cross_entropy(logits_per_image.T, soft_targets)) / 2


In [94]:
class Food101DataSet(Dataset):
    """Map-style dataset yielding ``(processor_inputs, caption)`` pairs for CLIP.

    Args:
        image_paths: image file paths.
        image_labels: caption strings, index-aligned with ``image_paths``.
        processor: callable invoked as ``processor(text=..., images=..., return_tensors='pt')``
            (a CLIPProcessor in this notebook).

    Raises:
        ValueError: if ``image_paths`` and ``image_labels`` differ in length,
            which would silently misalign images and captions.
    """
    def __init__(self, image_paths, image_labels, processor):
        if len(image_paths) != len(image_labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and image_labels ({len(image_labels)}) must have the same length"
            )
        self.image_paths = image_paths
        self.image_labels = image_labels
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is released as soon
        # as the pixels are decoded (convert() returns a new, loaded image).
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = self.image_labels[idx]
        inputs = self.processor(text=label, images=image, return_tensors='pt')
        # The processor returns a batch of one; drop that leading dimension so
        # the collate function can pad/stack per-sample tensors.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        return inputs, label

# Fixed batch size and shuffle seed so the batch order (and therefore the loss
# curve) is reproducible on Restart & Run All.
TRAIN_BATCH_SIZE = 8
TRAIN_SEED = 0

train_dataset = Food101DataSet(image_paths, image_labels, processor)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(TRAIN_SEED),
)

In [95]:
# Use the current CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# nn.Module.to moves the parameters in place and returns the module, which
# Jupyter displays as this cell's output.
model.to(device)

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e

In [96]:
# Freeze every weight, then re-enable gradients only for the `embeddings`
# submodules of the text and vision towers, so only those are fine-tuned.
for weight in model.parameters():
    weight.requires_grad = False

for embeddings in (model.text_model.embeddings, model.vision_model.embeddings):
    for weight in embeddings.parameters():
        weight.requires_grad = True

In [97]:
# Only the parameters left trainable by the freezing cell are optimised.
trainable_params = [p for p in model.parameters() if p.requires_grad]
# transformers.AdamW is deprecated (and removed in newer releases), so use
# torch.optim.AdamW instead. eps and weight_decay are pinned to
# transformers.AdamW's defaults (1e-6 / 0.0) to keep the optimisation unchanged;
# torch's own defaults are 1e-8 / 1e-2.
optimizer = torch.optim.AdamW(trainable_params, lr=5e-6, eps=1e-6, weight_decay=0.0)
epochs = 8
num_training_steps = len(train_dataloader) * epochs
# Linear warm-up over roughly the first 10% of all optimisation steps.
num_warmup_steps = int(0.1 * len(train_dataloader)) * epochs
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)



In [99]:
model.train()
losses = []  # loss of every optimisation step, across all epochs

for epoch in range(epochs):
    epoch_no = epoch + 1
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch_no} / {epochs}")
    for batch_inputs, labels in progress_bar:
        # Move every tensor of the batch to the training device.
        batch_inputs = {name: tensor.to(device) for name, tensor in batch_inputs.items()}

        logits_per_image = model(**batch_inputs).logits_per_image
        loss = clip_loss(logits_per_image)
        step_loss = loss.item()
        losses.append(step_loss)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        progress_bar.set_postfix(loss=step_loss)
    print(f"epoch {epoch_no} finished.")

print("Train completed.")

Epoch 1 / 8: 100%|██████████| 9469/9469 [21:38<00:00,  7.29it/s, loss=0.0083]  


epoch 1 finished.


Epoch 2 / 8: 100%|██████████| 9469/9469 [21:32<00:00,  7.32it/s, loss=0.0103]  


epoch 2 finished.


Epoch 3 / 8: 100%|██████████| 9469/9469 [21:30<00:00,  7.34it/s, loss=0.277]   


epoch 3 finished.


Epoch 4 / 8: 100%|██████████| 9469/9469 [21:43<00:00,  7.26it/s, loss=0.0964]  


epoch 4 finished.


Epoch 5 / 8: 100%|██████████| 9469/9469 [21:34<00:00,  7.32it/s, loss=0.0119]  


epoch 5 finished.


Epoch 6 / 8: 100%|██████████| 9469/9469 [21:39<00:00,  7.28it/s, loss=0.0234]  


epoch 6 finished.


Epoch 7 / 8: 100%|██████████| 9469/9469 [21:36<00:00,  7.31it/s, loss=0.0115]  


epoch 7 finished.


Epoch 8 / 8: 100%|██████████| 9469/9469 [21:34<00:00,  7.31it/s, loss=0.263]   

epoch 8 finished.
Train completed.





In [None]:
import matplotlib.pyplot as plt

# Per-step training loss over all epochs. Title and axis labels let the figure
# stand on its own when the notebook is skimmed.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(losses, linewidth=0.5)
ax.set(title="CLIP fine-tuning loss per step", xlabel="training step", ylabel="clip_loss")
plt.show()

In [16]:
# test dataset
# Root folder holding the split directories. Override with FOOD101_DATA_ROOT
# (e.g. to point the train and test cells at the same machine); the default is
# the path this cell originally used.
test_data_root = os.environ.get(
    "FOOD101_DATA_ROOT",
    "/Users/anjonghyeon/Desktop/KU/3-2/DeepLearning/20242R0136COSE47402/FinalProject/data",
)
data_dir = os.path.join(test_data_root, "test")

# One sub-directory per class, sorted so the order is deterministic.
class_candidate = sorted(
    folder for folder in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, folder))
)
# Candidate prompt per class, index-aligned with class_candidate
# (e.g. "apple_pie" -> "a photo of apple pie").
text_inputs = [f"a photo of {folder.replace('_', ' ')}" for folder in class_candidate]

In [17]:
image_paths = []
image_labels = []

for class_name in class_candidate:
    class_folder = os.path.join(data_dir, class_name)
    # Ground-truth caption, in the same format as the candidate prompts.
    caption = f"a photo of {class_name.replace('_', ' ')}"
    for img_name in sorted(os.listdir(class_folder)):
        # Skip hidden files (e.g. macOS .DS_Store): they are not images and
        # would make Image.open fail during evaluation.
        if img_name.startswith('.'):
            continue
        image_paths.append(os.path.join(class_folder, img_name))
        image_labels.append(caption)

print(f"test dataset size : {len(image_paths)}")

test dataset size : 25250


In [None]:
# Zero-shot accuracy on the test split: an image counts as correct when the
# highest-scoring prompt in `text_inputs` equals its caption in `image_labels`.
ans = 0
model.eval()
# Inputs must live on the same device as the weights (CUDA after training);
# CPU tensors fed to a CUDA model raise a device-mismatch error.
eval_device = next(model.parameters()).device

with torch.no_grad():  # inference only: skip autograd bookkeeping
    # The candidate prompts are identical for every image, so encode them once
    # and reuse the L2-normalised text embeddings.
    text_batch = processor(text=text_inputs, return_tensors='pt', padding=True).to(eval_device)
    text_features = model.get_text_features(**text_batch)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    for img_path, label in tqdm(zip(image_paths, image_labels), total=len(image_paths), desc="Evaluating"):
        # Context manager releases the file handle once the pixels are decoded.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        pixel_batch = processor(images=image, return_tensors='pt').to(eval_device)
        image_features = model.get_image_features(**pixel_batch)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # CLIPModel's logits_per_image is logit_scale times this cosine
        # similarity, and softmax is monotonic, so this argmax picks the same
        # prompt as the full forward pass followed by softmax.
        best = (image_features @ text_features.T).argmax(dim=-1).item()
        if text_inputs[best] == label:
            ans += 1

accuracy = ans / len(image_paths) * 100 if image_paths else float('nan')
accuracy