In [1]:
import json
import os

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# from transformers import ViTForImageClassification
# from transformers import ViTImageProcessor

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
# model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
# model.to(device)

In [3]:
from backbone import Linear_fw

class Classifier_FW(nn.Module):
    def __init__(self, input_size=768, num_layers=0, layer_size=5, num_classes=100):
        super().__init__()
        layers = self._generate_layers(input_size, num_layers, layer_size, num_classes)
        self.net = nn.Sequential(*layers)

    def _generate_layers(self, input_size, num_hidden_layers, layer_size, num_classes):
        if num_hidden_layers == 0:
            return [Linear_fw(input_size, num_classes)]

        layers = [Linear_fw(input_size, layer_size), nn.ReLU()]
        for _ in range(num_hidden_layers-1):
            layers.append(Linear_fw(layer_size, layer_size))
            layers.append(nn.ReLU())

        layers.append(Linear_fw(layer_size, num_classes))
        return layers
    
    def forward(self, x):
        return self.net(x)


In [4]:
with open("/home/lukasz/binary-hyper-maml/filelists/miniImagenet/all_vit.json", 'r') as f:
    meta = json.load(f)

In [5]:
class SampleDataset(Dataset):
    def __init__(self, data_file, transform=transforms.ToTensor()):
        with open(data_file, 'r') as f:
            self.meta = json.load(f)

        with open("/home/lukasz/binary-hyper-maml/filelists/omniglot/noLatin_vit.json", 'r') as f:
            self.real = json.load(f)

        self.transform = transform

    def __getitem__(self,i):
        image_path = os.path.join(self.meta["image_names"][i])
        # real_image_path = os.path.join(self.real["image_names"][i])
        emb = np.load(image_path)
        emb = self.transform(emb)
        target = self.meta["image_labels"][i]
        return emb, target

    def __len__(self):
        return len(self.meta["image_names"])

In [6]:
dataset = SampleDataset("/home/lukasz/binary-hyper-maml/filelists/miniImagenet/all_vit.json")

In [7]:
with open("/home/lukasz/binary-hyper-maml/filelists/miniImagenet/all_vit.json", 'r') as f:
    meta = json.load(f)

In [8]:
from torch.utils.data import DataLoader

In [9]:
# data_iter = iter(data_loader)
# batch = next(data_iter)
# embeddings, labels, images = batch

In [10]:
# for embedding, label, image in zip(embeddings, labels, images):
#     with torch.no_grad():
#         logits = model.classifier(embedding.cuda())
#     prediction = logits.argmax(-1)
#     print("Predicted class:", model.config.id2label[prediction.item()])
#     img = Image.open(image)
#     img.show()

In [11]:
import torch


def train(train_dataloader, model, error, optimizer, num_epochs=50):
    train_losses = []

    for epoch in range(num_epochs):
        accuracies = []
        for i, (x, y) in enumerate(train_dataloader):
            model.train()
            optimizer.zero_grad()
            x = x.view(x.shape[0], -1)
            y_pred = model(x)
            loss = error(y_pred, y)

            
            acc = (y == torch.argmax(y_pred, dim=1)).sum()/y.shape[0]
            accuracies.append(acc)

            loss.backward()
            optimizer.step()
        
        train_losses.append(loss.item())
        model.eval()

        print(f'Epoch no.: {epoch+1}, train loss = {train_losses[epoch]:.4f}, accuracy = {np.mean(accuracies):.4f}')

    return train_losses

In [12]:
learning_rate = 0.001

data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
model = Classifier_FW(input_size=768, num_layers=2, layer_size=128, num_classes=100)
# model = Linear_fw(in_features=768, out_features=100)
error = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [13]:
model

Classifier_FW(
  (net): Sequential(
    (0): Linear_fw(in_features=768, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear_fw(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear_fw(in_features=128, out_features=100, bias=True)
  )
)

In [14]:
train(data_loader, model, error, optimizer)

Epoch no.: 1, train loss = 0.1200, accuracy = 0.9423
Epoch no.: 2, train loss = 0.0226, accuracy = 0.9706
Epoch no.: 3, train loss = 0.3743, accuracy = 0.9788
Epoch no.: 4, train loss = 0.0069, accuracy = 0.9830
Epoch no.: 5, train loss = 0.0569, accuracy = 0.9855
Epoch no.: 6, train loss = 0.0044, accuracy = 0.9865
Epoch no.: 7, train loss = 0.0156, accuracy = 0.9882
Epoch no.: 8, train loss = 0.0004, accuracy = 0.9895
Epoch no.: 9, train loss = 0.0001, accuracy = 0.9893
Epoch no.: 10, train loss = 0.0888, accuracy = 0.9909


KeyboardInterrupt: 