In [2]:

import os

import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import models
from torchvision import transforms

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class ImageEncoder(nn.Module):
    """Image encoder: a torchvision ResNet-50 followed by a linear projection.

    The complete ResNet-50 (including its own final ``fc`` layer) is used as the
    backbone; its output is then mapped to ``embedding_size`` features by an
    additional linear layer.

    Args:
        embedding_size: Number of output features of the projection layer.
    """

    def __init__(self, embedding_size):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # The projection's input width is whatever the backbone's last layer emits.
        projection = nn.Linear(backbone.fc.out_features, embedding_size)
        # Attribute names define the state_dict keys, so they must stay as-is.
        self.resnet50 = backbone
        self.fc = projection

    def forward(self, x):
        """Return the projected backbone output for the input batch ``x``."""
        return self.fc(self.resnet50(x))

In [5]:
# Checkpoint path without the '.pth' extension.
model_name =  'model/model_image_2023-07-08'
# Build the encoder with a 768-feature output and move it to the selected device.
image_model = ImageEncoder(768).to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# saved from a CUDA device can still be loaded when only the CPU is available.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
image_model.load_state_dict(torch.load(model_name + '.pth', map_location=device))
# Switch to evaluation mode; eval() returns the module, so the cell displays it.
image_model.eval()

ImageEncoder(
  (resnet50): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
    

In [6]:
class FDataset(Dataset):
    """Dataset of (image tensor, caption) pairs described by a CSV file.

    The CSV is read with pandas; column 0 is used as the image path and
    column 1 as the caption.

    Args:
        annotations_file: Path of the CSV file passed to ``pd.read_csv``.
    """

    def __init__(self, annotations_file):
        self.data = pd.read_csv(annotations_file)
        # Resize -> centre crop to 224 -> tensor -> channel-wise normalisation.
        preprocessing_steps = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(preprocessing_steps)

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Load, preprocess and return the ``idx``-th (image, caption) pair."""
        img_path = self.data.iloc[idx, 0]
        picture = Image.open(img_path)
        if picture.mode != 'RGB':
            # Any non-RGB image is reduced to a single luminance band that is
            # then repeated three times, i.e. it is fed to the model as grayscale.
            grey = picture.convert('L')
            picture = Image.merge('RGB', (grey, grey, grey))
        tensor = self.transform(picture)
        return tensor, self.data.iloc[idx, 1]

In [7]:
# Annotation table listing the images to embed.
dataset = FDataset('./data/anotation_new.csv')
# shuffle=False keeps the embeddings in the same order as the CSV rows.
dataloader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
)
# Empty tensor that the embedding batches get concatenated onto.
image_embeddings = torch.Tensor([])


In [8]:
# Embed every batch with the encoder and gather the results on the CPU.
# The accumulators are initialised inside this cell so that re-running it starts
# from scratch instead of appending to embeddings left over from a previous run.
embedding_chunks = []
num_embedded = 0
with torch.no_grad():  # inference only, no autograd bookkeeping needed
    for i, (img, _) in enumerate(dataloader):
        img = img.to(device)
        image_embedding = image_model(img).cpu()
        embedding_chunks.append(image_embedding)
        num_embedded += len(image_embedding)
        # Report progress every 100 batches.
        if i % 100 == 0:
            print(f"finish: {num_embedded * 100/len(dataset)}%")

# Concatenate once at the end: calling torch.cat inside the loop re-copies all
# previously collected rows on every batch, which is quadratic in the batch count.
if embedding_chunks:
    image_embeddings = torch.cat(embedding_chunks, dim=0)
else:
    # Empty dataloader: keep the same empty result the accumulator started with.
    image_embeddings = torch.Tensor([])

finish: 0.02015133970986794%
finish: 2.0352853106966617%
finish: 4.050419281683456%
finish: 6.06555325267025%
finish: 8.080687223657042%
finish: 10.095821194643836%
finish: 12.11095516563063%
finish: 14.126089136617425%
finish: 16.14122310760422%
finish: 18.15635707859101%
finish: 20.171491049577806%
finish: 22.1866250205646%
finish: 24.201758991551394%
finish: 26.216892962538186%
finish: 28.23202693352498%
finish: 30.247160904511773%
finish: 32.26229487549857%
finish: 34.277428846485364%
finish: 36.29256281747216%
finish: 38.30769678845895%
finish: 40.32283075944574%
finish: 42.33796473043254%
finish: 44.353098701419334%
finish: 46.36823267240612%
finish: 48.38336664339292%
finish: 50.398500614379714%
finish: 52.41363458536651%
finish: 54.4287685563533%
finish: 56.44390252734009%
finish: 58.45903649832689%
finish: 60.474170469313684%
finish: 62.48930444030047%
finish: 64.50443841128727%
finish: 66.51957238227406%
finish: 68.53470635326086%
finish: 70.54984032424765%
finish: 72.5649742

In [10]:
# Persist the embeddings. The target folder is created first because torch.save
# does not create missing parent directories, and a failure here would discard
# the result of the long embedding run above.
output_path = 'image_tensor/model_image_2023-07-08/resnet50_tensor.pt'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torch.save(image_embeddings, output_path)