In [133]:
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from torchvision import transforms
from torchvision.models import resnet18
import torchvision.models as models
import pickle
import pandas as pd



### Using GPU

In [134]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on {device}")

Training on cuda


### Loading Dataset

In [135]:
class FewShotDataset(Dataset):
    def __init__(self, data_path):
        with open(data_path, 'rb') as f:
            data_dict = pickle.load(f)
        self.images = torch.from_numpy(data_dict['images']).float()
        self.labels = torch.from_numpy(data_dict['labels']).long()
        # self.images = data_dict['images']
        # self.labels = data_dict['labels']
        

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        return image, label

class FewShotTestDataset(Dataset):
    def __init__(self, data_path):
        with open(data_path, 'rb') as f:
            data_dict = pickle.load(f)
        self.sup_images = torch.from_numpy(data_dict['sup_images']).float()
        self.sup_labels = torch.from_numpy(data_dict['sup_labels']).long()
        self.qry_images = torch.from_numpy(data_dict['qry_images']).float()
    
    def get_labels(self):
        return self.sup_labels
    
    



In [136]:
train_data = FewShotDataset('train.pkl')
val_data = FewShotDataset('validation.pkl')
test_data = FewShotTestDataset('test.pkl')

In [137]:
# transform = torchvision.transforms.ToPILImage()
# img = transform(torch.tensor(test_data.qry_images[0][2])*255)
# img.show()
test_labels = test_data.get_labels()
for i in test_labels:
    print(i)

tensor([3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 2, 2, 2, 2,
        2])
tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 3, 3, 3, 3,
        3])
tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2,
        2])
tensor([3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 4, 4, 4, 4, 4, 2, 2, 2, 2,
        2])
tensor([2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 1, 1, 1, 1,
        1])
tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 3, 3, 3, 3,
        3])
tensor([4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 3, 3, 3, 3,
        3])
tensor([3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 2, 2, 2, 2,
        2])
tensor([1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 3, 3, 3, 3,
        3])
tensor([0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 1, 1, 1, 1,
        1])
tensor([2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 3, 3, 3, 3,


### Defind FewshotModel

In [138]:
class FewShotModel(nn.Module):
    def __init__(self):
        super(FewShotModel, self).__init__()
        self.resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        num_ftrs = self.resnet18.fc.in_features
        self.resnet18.fc = nn.Linear(num_ftrs, 64)

    def forward(self, x):
        x = self.resnet18.conv1(x)
        x = self.resnet18.bn1(x)
        x = self.resnet18.relu(x)
        x = self.resnet18.maxpool(x)

        x = self.resnet18.layer1(x)
        x = self.resnet18.layer2(x)
        x = self.resnet18.layer3(x)
        x = self.resnet18.layer4(x)

        x = self.resnet18.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.resnet18.fc(x)
        return x

### Model Train Function

In [139]:
def train(model, train_data, val_data, num_epochs=20, batch_size=32, learning_rate=0.001):
    train_dataset = train_data
    val_dataset = val_data
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_acc = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            train_acc += torch.sum(preds == labels.data)
        train_loss /= len(train_loader.dataset)
        train_acc /= len(train_loader.dataset)
        model.eval()
        val_loss = 0.0
        val_acc = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
                _, preds = torch.max(outputs, 1)
                val_acc += torch.sum(preds == labels.data)
        val_loss /= len(val_loader.dataset)
        val_acc /= len(val_loader.dataset)
        print('Epoch: {} \tTraining Loss: {:.6f} \tTraining Accuracy: {:.6f} \tValidation Loss: {:.6f} \tValidation Accuracy: {:.6f}'.format(epoch+1, train_loss, train_acc, val_loss, val_acc))
    torch.save(model.state_dict(), 'model.pth')

### Model Test Function

In [140]:
def test(model, test_data, num_task=600):
    sup_images = test_data.sup_images
    sup_labels = test_data.sup_labels
    qry_images = test_data.qry_images
    num_correct = 0
    pred_arr = []
    for i in range(num_task):
        support_set = sup_images[i].to(device)
        support_labels = sup_labels[i].to(device)
        query_set = qry_images[i].to(device)
        query_labels = np.arange(5)
        query_labels = np.tile(query_labels, 5)
        query_labels = torch.from_numpy(query_labels).long().to(device)
        # print(query_labels)
        # support_set = torch.from_numpy(support_set).float()
        # support_set = support_set.float()
        # support_labels = torch.from_numpy(support_labels).long()
        # query_set = torch.from_numpy(query_set).float()
        model.eval()
        with torch.no_grad():
            support_features = model(support_set)
            query_features = model(query_set)
            # support_features = support_features.unsqueeze(0).repeat(25, 1, 1)
            # support_features = support_features.unsqueeze(0).repeat(25, 1, 1)
            # query_features = query_features.unsqueeze(1).repeat(1, 25, 1)
            # distances = torch.sum((support_features - query_features) ** 2, dim=2)
            # preds = torch.argmin(distances, dim=1)
            dists = torch.cdist(query_features, support_features)
            _, preds = torch.min(dists, dim=1)
            preds = preds.cpu().numpy()
            # num_correct += torch.sum(preds == query_labels)
            # print(support_labels)
            # print(preds)
            
            label_dict = {}
            for j in range(len(support_labels)):
                label_dict[preds[j]] = support_labels[preds[j]]
            
            for j in range(len(preds)):
                preds[j] = label_dict[preds[j]]

            pred_arr.append(preds)
    
    pred_arr = np.array(pred_arr)
    pred_arr = pred_arr.reshape(-1)
    pred_data = {"Category":pred_arr}
    df_pred = pd.DataFrame(pred_data)
    df_pred.to_csv('310581027_pred.csv', index_label='Id')
    
    # pred_data = {"pred":pred_arr}
    # df_pred = pd.DataFrame(pred_data)
    # df_pred.to_csv('310581027_pred.csv', index_label='id')
    # accuracy = num_correct / (num_task * 25)
    # print('Test Accuracy: {:.6f}'.format(accuracy))

### Create Model and Training

In [141]:
model = FewShotModel().to(device)
train(model, train_data, val_data)

Epoch: 1 	Training Loss: 2.412719 	Training Accuracy: 0.371198 	Validation Loss: 7.087693 	Validation Accuracy: 0.014375
Epoch: 2 	Training Loss: 1.728055 	Training Accuracy: 0.526302 	Validation Loss: 7.532542 	Validation Accuracy: 0.011354
Epoch: 3 	Training Loss: 1.365813 	Training Accuracy: 0.615755 	Validation Loss: 7.963136 	Validation Accuracy: 0.013333
Epoch: 4 	Training Loss: 1.061913 	Training Accuracy: 0.693255 	Validation Loss: 8.215373 	Validation Accuracy: 0.013542
Epoch: 5 	Training Loss: 0.791223 	Training Accuracy: 0.763828 	Validation Loss: 9.863120 	Validation Accuracy: 0.013021
Epoch: 6 	Training Loss: 0.569267 	Training Accuracy: 0.827396 	Validation Loss: 10.553090 	Validation Accuracy: 0.010417
Epoch: 7 	Training Loss: 0.427693 	Training Accuracy: 0.867995 	Validation Loss: 11.141880 	Validation Accuracy: 0.013333
Epoch: 8 	Training Loss: 0.346669 	Training Accuracy: 0.889635 	Validation Loss: 11.692587 	Validation Accuracy: 0.017500
Epoch: 9 	Training Loss: 0.28

### Model Testing

In [142]:
test(model, test_data)