In [26]:
import pandas as pd
import torch
from PIL import Image
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import json

In [28]:
# Sanity check: inspect the pixel size of one sample image.
# The context manager closes the underlying file handle instead of leaving it
# open until garbage collection.
with Image.open("lfw-deepfunneled/Aaron_Eckhart/Aaron_Eckhart_0001.jpg") as sample_image:
    sample_size = sample_image.size
sample_size

(250, 250)

In [19]:
# Load the train / validation pair listings and the label -> index mapping.
df_train = pd.read_csv("dataset_train_receipt.csv")
df_val = pd.read_csv("dataset_val_receipt.csv")
# JSON is UTF-8 by specification; pin the encoding so that non-ASCII labels are
# decoded the same way regardless of the platform's locale default.
with open('label2idx.json', encoding="utf-8") as f:
    label2idx = json.load(f)

In [None]:
class SiameseDataset(Dataset):
    """Dataset of labelled image pairs for Siamese-network training.

    Each row of ``df`` is unpacked positionally as ``(label, img1, img2)``,
    where ``img1`` / ``img2`` are handed to ``PIL.Image.open`` and ``label`` is
    looked up in ``label2idx``.
    NOTE(review): the column order of the CSV is assumed to match that
    unpacking -- confirm against the files.

    Args:
        df: pandas DataFrame with exactly three columns (see above).
        label2idx: mapping from the raw label value to a numeric target.
        is_train: stored on the instance. Images are loaded in both modes;
            the flag is available for train-only augmentation.
        image_size: ``(height, width)`` every image is resized to.
        image_mode: PIL mode images are converted to before resizing. The
            default ``"L"`` yields single-channel tensors. Pass ``None`` to
            keep the file's own mode.
    """

    def __init__(self, df, label2idx, is_train: bool, image_size=(224, 224), image_mode="L"):
        self.df = df.values
        self.label2idx = label2idx
        self.is_train = is_train
        self.image_mode = image_mode
        self.resize = transforms.Resize(image_size)
        self.to_tensor = transforms.ToTensor()

    def transform(self, x):
        """Load the image at path ``x`` and return it as a float tensor (C, H, W)."""
        with Image.open(x) as img:
            # convert()/copy() force the lazy PIL loader to read the pixels,
            # so the file can be closed safely when the block exits.
            img = img.convert(self.image_mode) if self.image_mode else img.copy()
        # ToTensor is required: the default DataLoader collate function cannot
        # batch PIL images, and the model expects tensors.
        return self.to_tensor(self.resize(img))

    def __getitem__(self, index):
        """Return ``(image1, image2, target)`` for the pair at ``index``."""
        label, img1, img2 = self.df[index]
        # Always load the images: returning raw path strings when
        # ``is_train`` is False would make validation batches unusable.
        img1 = self.transform(img1)
        img2 = self.transform(img2)
        target = torch.tensor(self.label2idx[label], dtype=torch.float32)
        return img1, img2, target

    def __len__(self):
        """Number of image pairs."""
        return len(self.df)

In [None]:


# Siamese Network 모델 정의
class SiameseNetwork(nn.Module):
    """Twin-branch CNN that embeds two images with shared weights.

    Both inputs go through the same convolutional stack and fully connected
    head, producing one ``embedding_dim``-dimensional vector per image.

    Args:
        in_channels: number of channels of the input images.
        input_size: ``(height, width)`` of the input images. The first linear
            layer is sized from this, so it must match the images actually
            fed to the network.
        embedding_dim: size of the output representation.

    Raises:
        ValueError: if ``input_size`` is too small for the convolution stack.
    """

    def __init__(self, in_channels=1, input_size=(224, 224), embedding_dim=5):
        super(SiameseNetwork, self).__init__()
        self.convolution = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=10),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=7),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        # The previous hard-coded 128 * 53 * 53 did not match a 224x224 input
        # (which yields 128 * 50 * 50), so the first forward pass failed with
        # a shape mismatch. Derive the size from the input resolution instead.
        self.fc = nn.Sequential(
            nn.Linear(self._flat_features(input_size), 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, embedding_dim)
        )

    @staticmethod
    def _flat_features(input_size):
        """Return the flattened feature count ``self.convolution`` produces.

        Mirrors the stack above: two unpadded stride-1 convolutions with
        kernel sizes 10 and 7, each followed by a 2x2 max-pool, ending with
        128 channels.
        """
        height, width = input_size
        for kernel_size in (10, 7):
            height = (height - kernel_size + 1) // 2
            width = (width - kernel_size + 1) // 2
            if height < 1 or width < 1:
                raise ValueError(
                    f"input_size {tuple(input_size)} is too small for the convolution stack"
                )
        return 128 * height * width

    def forward_one(self, x):
        """Embed one batch of images shaped (batch, in_channels, H, W)."""
        x = self.convolution(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x

    def forward(self, input1, input2):
        """Embed both members of each pair with the shared weights."""
        output1 = self.forward_one(input1)
        output2 = self.forward_one(input2)
        return output1, output2

# The image-pair dataset class (SiameseDataset) is defined in an earlier cell.


# Siamese Network 학습
def train_siamese_network(model, dataloader, criterion, optimizer, num_epochs=10):
    """Train ``model`` on batches of image pairs.

    Args:
        model: module called as ``model(img1, img2)`` returning two embeddings.
        dataloader: iterable yielding ``(img1, img2, label)`` batches.
        criterion: callable ``criterion(output1, output2, label)`` returning a
            scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``dataloader``.

    Returns:
        List with one entry per epoch: the unweighted mean of the per-batch
        losses, or ``nan`` for an epoch in which ``dataloader`` yielded no
        batches.
    """
    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for img1, img2, label in dataloader:
            optimizer.zero_grad()
            output1, output2 = model(img1, img2)
            loss = criterion(output1, output2, label)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Report the epoch mean rather than whatever the last batch produced,
        # and avoid referencing an unbound ``loss`` when no batch was seen.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(mean_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss}')
    return epoch_losses

# Siamese Network의 Contrastive Loss 정의
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairs of embeddings.

    For each pair with Euclidean distance ``d`` the per-sample loss is
    ``(1 - label) * d**2 + label * clamp(margin - d, min=0)**2``. So
    ``label == 0`` pulls a pair together and ``label == 1`` pushes it at
    least ``margin`` apart. The result is the mean over the batch.

    Args:
        margin: distance beyond which ``label == 1`` pairs stop contributing.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the scalar mean contrastive loss for a batch of pairs."""
        # ``F`` was never imported in this notebook; go through ``nn.functional``.
        euclidean_distance = nn.functional.pairwise_distance(output1, output2)
        # Align label with the (batch,) distance vector: a (batch, 1) label
        # would otherwise broadcast into a (batch, batch) matrix silently.
        label = label.reshape(euclidean_distance.shape).to(euclidean_distance.dtype)
        pull_term = (1 - label) * torch.pow(euclidean_distance, 2)
        push_term = label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        loss_contrastive = torch.mean(pull_term + push_term)
        return loss_contrastive

# Dataset and DataLoader setup.
# (Adjust this part to match the actual dataset.)
# NOTE(review): SiameseDataset takes no `transform` argument; this object is
# kept only so that any later cell referencing `transform` still finds it.
transform = transforms.Compose([transforms.ToTensor()])
# `data` was never defined; build the dataset from the training table and the
# label mapping loaded in the data-loading cell, using the constructor's
# actual signature (df, label2idx, is_train).
dataset = SiameseDataset(df_train, label2idx, is_train=True)
dataloader = DataLoader(dataset, shuffle=True, batch_size=64)

# Siamese network and training configuration.
siamese_net = SiameseNetwork()
criterion = ContrastiveLoss()
optimizer = optim.Adam(siamese_net.parameters(), lr=0.0005)

# Train the Siamese network.
train_siamese_network(siamese_net, dataloader, criterion, optimizer, num_epochs=10)

# Extract embeddings for the first image of every pair.
# A non-shuffled loader keeps row i of `embeddings` aligned with row i of the
# dataset, so cluster assignments can be joined back to the source table.
eval_loader = DataLoader(dataset, shuffle=False, batch_size=64)
embedding_batches = []
label_batches = []
siamese_net.eval()
with torch.no_grad():
    for img, _, label in eval_loader:
        output = siamese_net.forward_one(img)
        embedding_batches.append(output.cpu().numpy())
        label_batches.append(label.cpu().numpy())

embeddings = np.concatenate(embedding_batches)
labels = np.concatenate(label_batches)

# Run K-Means clustering on the embeddings.
# `KMeans` was used without ever being imported; it is scikit-learn's estimator.
from sklearn.cluster import KMeans

# Choose the number of clusters to suit the dataset. `random_state` makes the
# assignment reproducible; `n_init` is pinned so the result does not depend on
# the installed scikit-learn version's default.
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0)
clusters = kmeans.fit_predict(embeddings)

# Print the clustering result.
print("Cluster Assignments:")
print(clusters)
