In [54]:
import tqdm
import torchvision.models.resnet as resnet
import torch
import os
import argparse
from dotenv import load_dotenv, find_dotenv
from src.data.dataloader import UseMetaData, ValTransforms
from torch.utils.data import DataLoader
import pickle

In [55]:
# Load the cached training embeddings from disk.
# The context manager closes the file handle as soon as unpickling is done
# (the bare open(...) call left it open until garbage collection).
# NOTE: the name `dict` shadows the builtin type for the rest of the kernel
# session; it is kept unchanged because the following cell reads it.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles produced by a trusted source.
with open('../../data/train_embeddings.pkl', 'rb') as embeddings_file:
    dict = pickle.load(embeddings_file)

In [56]:
# Split every pickled entry into three parallel lists of tensors.
# Each entry is indexed positionally: [0] goes to the image list, [1] to the
# metadata list and [2] to the label list. torch.from_numpy wraps each numpy
# array without copying it.
train_img_data, train_meta_data, train_labels = [], [], []

for entry in dict.values():
    train_img_data.append(torch.from_numpy(entry[0]))
    train_meta_data.append(torch.from_numpy(entry[1]))
    train_labels.append(torch.from_numpy(entry[2]))

In [57]:
# Collapse the per-entry lists into single tensors by concatenating along dim 0.
# NOTE: each name is rebound from a list of tensors to one tensor here, so this
# cell expects the lists built by the previous cell as its input.
train_img_data = torch.cat(train_img_data, dim=0).detach()
train_meta_data = torch.cat(train_meta_data, dim=0)
train_labels = torch.cat(train_labels, dim=0)

# Image features and metadata side by side (joined along dim 1).
train_cat_data = torch.cat((train_img_data, train_meta_data), dim=1)

In [59]:
from sklearn.neighbors import KNeighborsClassifier

# Compare a 5-nearest-neighbour classifier on two feature sets: the image
# features alone, and the image features concatenated with the metadata.
# NOTE: the classifier is scored on the same samples it was fitted on, so the
# printed numbers are training accuracies, not held-out accuracies.

# The labels are identical for both feature sets, so convert them only once.
labels_np = train_labels.numpy()

# The loop variable is deliberately not called `train_data`: that name is bound
# to the UseMetaData dataset in another cell of this notebook, and looping with
# it silently replaced the dataset with a tensor.
for features in (train_img_data, train_cat_data):
    features_np = features.numpy()
    classifier = KNeighborsClassifier(n_neighbors=5)
    classifier.fit(features_np, labels_np)
    preds = classifier.predict(features_np)
    acc = (preds == labels_np).mean()
    print(acc)

0.7724891544569044
0.7731235164115576


In [7]:
# Filesystem locations handed to the dataset constructors.
# NOTE(review): `path` is an absolute, machine-specific mount point; adjust it
# when running the notebook on another machine.
path = '/mnt/f/MetalabelIntegration/'
annotation_path = "../../data/annotations/"

# Locate a .env file and load its variables into the process environment.
dotenvpath = find_dotenv()
load_dotenv(dotenvpath)

In [9]:
# Build the train and val datasets; both are given a ValTransforms() instance.
train_data = UseMetaData("train", path, annotation_path, transform=ValTransforms())
val_data = UseMetaData("val", path, annotation_path, transform=ValTransforms())

number_of_classes = len(train_data.classes)

# Both loaders share exactly the same settings, so they are spelled out once.
# A dict literal is used on purpose: the builtin name `dict` is shadowed by the
# pickled embeddings elsewhere in this notebook.
loader_options = {
    "batch_size": 16,
    "num_workers": 8,
    "pin_memory": True,
    "shuffle": True,
}

train_loader = DataLoader(train_data, **loader_options)
val_loader = DataLoader(val_data, **loader_options)

In [10]:
# ImageNet-pretrained ResNet-50 used as a frozen feature extractor.
# It is built from the torchvision `resnet` module imported at the top of the
# notebook instead of torch.hub: the previous hub tag ("pytorch/vision:v0.9.0")
# predates the `weights=` argument, so that call only succeeded when a newer
# installed torchvision happened to be picked up, and it also needed the hub
# repository in addition to the weight file.
model = resnet.resnet50(weights="ResNet50_Weights.IMAGENET1K_V1")

# Swap the final fully connected layer for Identity so the forward pass returns
# the tensor that used to be the input of `fc` instead of class logits.
model.fc = torch.nn.Identity()

# Switch to evaluation mode before running inference.
model.eval()

Using cache found in /home/juliu/.cache/torch/hub/pytorch_vision_v0.9.0


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
# Upper bound on the number of batches to embed. The original loop stopped
# after batch index 101, i.e. after 102 batches; that count is preserved.
MAX_BATCHES = 102

train_img_data = []
train_meta_data = []
train_labels = []

# Inference only. Without no_grad every stored model output keeps its autograd
# graph alive until the final concatenation, so memory grows with each batch.
with torch.no_grad():
    for i, batch in enumerate(tqdm.tqdm(train_loader)):
        # batch[0] is fed to the model; batch[1] and batch[2] are stored as the
        # metadata and label tensors respectively.
        train_img_data.append(model(batch[0]))
        train_meta_data.append(batch[1])
        train_labels.append(batch[2])
        if i + 1 >= MAX_BATCHES:
            break

# Concatenate the per-batch tensors along dim 0. No .detach() is needed because
# the model outputs were produced under no_grad.
train_img_data = torch.cat(train_img_data, 0)
train_meta_data = torch.cat(train_meta_data, 0)
# Image features and metadata side by side (joined along dim 1).
train_cat_data = torch.cat([train_img_data, train_meta_data], 1)
train_labels = torch.cat(train_labels, 0)