In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, random_split
import os
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import json
import clip

%config InlineBackend.figure_format = 'retina'

In [None]:
# Pick the compute device. The notebook prefers the second GPU ("cuda:1"), but
# a host with a single GPU still reports CUDA as available, so fall back to
# "cuda:0" there instead of failing later with an invalid device ordinal.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

In [None]:
# Load CLIP directly onto the selected device (rather than onto the library's
# default device followed by a move). jit=False requests the regular nn.Module
# so the weights can be fine-tuned.
model, preprocess = clip.load("ViT-L/14@336px", device=device, jit=False)

# NOTE(review): the reference CLIP implementation keeps weights in fp16 on CUDA;
# Adam's default eps is not representable in fp16 and commonly produces NaNs
# during fine-tuning, so train in fp32. This raises GPU memory use — confirm
# the batch size still fits.
model = model.float()
model.eval()

In [None]:
# Fix the random seed so runs are repeatable.
seed_value = 42

# Seed PyTorch's default generator.
torch.manual_seed(seed_value)

# When CUDA is present, also seed the current GPU and then every GPU.
if torch.cuda.is_available():
    for seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed_value)

In [None]:
class image_text_dataset():
    """Map-style dataset pairing image files with CLIP-tokenized captions.

    Captions are tokenized once at construction time; images are opened and
    preprocessed lazily on each ``__getitem__`` call.
    """

    def __init__(self, list_image_path, list_txt):
        """Store the image paths and tokenize the captions.

        Args:
            list_image_path: Indexable sequence of image file paths.
            list_txt: Sequence of caption strings (a single string is treated
                as a one-element list), aligned index-by-index with
                ``list_image_path``.

        Raises:
            ValueError: If the number of images and captions differ.
        """
        # A bare string is one caption, not a sequence of characters.
        if isinstance(list_txt, str):
            list_txt = [list_txt]
        # Validate before tokenizing so a mismatch fails here, not as an
        # IndexError (or silently unused images) in the middle of an epoch.
        if len(list_image_path) != len(list_txt):
            raise ValueError(
                f"image/text count mismatch: {len(list_image_path)} images "
                f"vs {len(list_txt)} texts"
            )
        # Initialize image paths and corresponding texts
        self.image_path = list_image_path
        # Tokenize text using CLIP's tokenizer
        self.title = clip.tokenize(list_txt)

    def __len__(self):
        """Return the number of tokenized captions."""
        return len(self.title)

    def __getitem__(self, idx):
        """Return ``(preprocessed_image, tokenized_caption)`` for ``idx``."""
        # Preprocess inside the context manager so the file handle is closed
        # as soon as the image has been converted.
        with Image.open(self.image_path[idx]) as img:
            image = preprocess(img)
        title = self.title[idx]
        return image, title

In [None]:
# Resolve the dataset root once. os.environ[...] raises a KeyError naming the
# variable when it is unset, instead of the opaque "NoneType + str" TypeError
# that os.getenv(...) + '/labels/' would produce.
data_root = os.environ['DATA_FOR_TUNING']
label_dir = data_root + '/labels/'
image_path = data_root + '/images/'

# Sort the label file names so the image/caption order is deterministic.
label_path = sorted(os.listdir(label_dir))

image_list = []
label_list = []
for path in label_path:
    with open(label_dir + path, 'r') as f:
        json_data = json.load(f)
    # Each label file contributes its first image entry and the 'english'
    # field of its first annotation.
    image_list.append(image_path + json_data['images'][0]['file_name'])
    label_list.append(json_data['annotations'][0]['english'])

In [None]:
# Build the paired dataset from the image paths and captions collected in
# image_list / label_list. Without this, `dataset` is undefined on a fresh
# kernel and the cell fails with NameError.
dataset = image_text_dataset(image_list, label_list)

# Compute split sizes: 60% train, 20% validation, remainder test (so the three
# sizes always sum to the dataset size despite integer truncation).
dataset_size = len(dataset)
train_size = int(dataset_size * 0.6)
validation_size = int(dataset_size * 0.2)
test_size = dataset_size - (train_size + validation_size)

# Split the dataset
train_dataset, validation_dataset, test_dataset = random_split(dataset, [train_size, validation_size, test_size])

# Create the data loaders; only the training loader shuffles.
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Adam over every parameter of the CLIP model, with a learning rate of 1e-5.
optimizer = optim.Adam(params=model.parameters(), lr=1e-5)
# Cross-entropy, applied to the image->text and text->image logits during
# training and evaluation.
criterion = nn.CrossEntropyLoss()

In [None]:
# Training loop
num_epochs = 5  # number of epochs
# Start from +inf so the first finite validation loss is always an improvement.
best_val_loss = float("inf")

for epoch in range(num_epochs):
    # Training phase
    model.train()  # put the model in training mode
    train_loss_sum = 0.0
    for images, texts in train_loader:
        images, texts = images.to(device), texts.to(device)
        logits_per_image, logits_per_text = model(images, texts)
        # The i-th image matches the i-th text, so the target class for row i
        # is simply i.
        ground_truth = torch.arange(images.size(0), dtype=torch.long, device=device)

        # Symmetric loss: average of image->text and text->image cross-entropy.
        loss = (criterion(logits_per_image, ground_truth) +
                criterion(logits_per_text, ground_truth)) / 2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss_sum += loss.item()

    # Report the mean loss over all batches of the epoch, not just the last
    # mini-batch, so the number is comparable with the validation loss below.
    avg_train_loss = train_loss_sum / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {avg_train_loss}")

    # Validation phase
    model.eval()  # put the model in evaluation mode
    total_loss = 0
    with torch.no_grad():
        for images, texts in validation_loader:
            images, texts = images.to(device), texts.to(device)
            logits_per_image, logits_per_text = model(images, texts)
            ground_truth = torch.arange(images.size(0), dtype=torch.long, device=device)

            loss = (criterion(logits_per_image, ground_truth) +
                    criterion(logits_per_text, ground_truth)) / 2
            total_loss += loss.item()

    avg_loss = total_loss / len(validation_loader)

    # Keep only the checkpoint with the lowest validation loss seen so far.
    if avg_loss < best_val_loss:
        best_val_loss = avg_loss
        torch.save(model.state_dict(), "best_model.pt")

    print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {avg_loss}")


In [None]:
# Restore the best checkpoint written by the training loop. map_location remaps
# the saved tensors onto the current `device`, so the file still loads when the
# GPU index it was saved from is not available.
model.load_state_dict(torch.load("best_model.pt", map_location=device))

In [None]:
# Test phase
model.eval()  # put the model in evaluation mode
total_loss = 0
total_correct = 0
total_images = 0

with torch.no_grad():
    for images, texts in test_loader:
        images = images.to(device)
        texts = texts.to(device)
        logits_per_image, logits_per_text = model(images, texts)
        batch_count = images.size(0)
        # Row i's correct match is column i.
        ground_truth = torch.arange(batch_count, dtype=torch.long, device=device)

        # Symmetric loss over both retrieval directions.
        image_side_loss = criterion(logits_per_image, ground_truth)
        text_side_loss = criterion(logits_per_text, ground_truth)
        loss = (image_side_loss + text_side_loss) / 2
        total_loss += loss.item()

        # Accuracy: an image counts as correct when its highest-scoring text
        # within the batch is its own caption.
        predicted = torch.max(logits_per_image, 1).indices
        total_correct += (predicted == ground_truth).sum().item()
        total_images += batch_count

# Mean of the per-batch losses, and accuracy as a percentage.
avg_loss = total_loss / len(test_loader)
accuracy = total_correct / total_images * 100
print(f"Test Loss: {avg_loss}, Accuracy: {accuracy}%")