# Состав команды
### Ячменьков Алексей:
#### 
#####
### Баранов Владислав:
#### 
#####
### Аллабердин Богдан:
#### 

In [14]:
import os
from pathlib import Path

import os
import cv2
import wandb
import numpy as np
import matplotlib.pyplot as plt
import shutil
from tqdm.notebook import tnrange

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.autograd import Variable
from PIL import Image

from facenet_pytorch import MTCNN, InceptionResnetV1

# create train test

In [2]:
# 1) загрузить zip и txt (название картинки и label)
# 2) каждую картинку обрезать и подогнать под размер
# 3) сохранить в отдельную папку все эти штуки
# 4) read_images с разделением на train и test

In [None]:
def preprocessing_source_photos(source_folder):
    """Detect and crop the face in every image of ``source_folder`` and show it.

    Each file is opened with PIL, passed through an MTCNN detector configured
    for 160x160 crops, and the crop (when a face is found) is opened in the
    default image viewer. Files in which no face is detected are skipped.

    Args:
        source_folder: Directory whose entries are all treated as image files.

    Returns:
        None. The function only displays the crops.
    """
    image_names = sorted(os.listdir(source_folder))
    if not image_names:
        # Nothing to process: avoid building the detector at all.
        return

    mtcnn = MTCNN(image_size=160)
    for image_name in image_names:
        image_path = os.path.join(source_folder, image_name)
        # The context manager closes the file handle once detection is done.
        with Image.open(image_path) as source_image:
            face = mtcnn(source_image.convert('RGB'))

        if face is None:
            continue

        # MTCNN returns a CxHxW torch tensor, which has no ``show()`` method.
        # NOTE(review): assumes the default ``post_process=True`` output is
        # standardized as (x - 127.5) / 128 — confirm against facenet-pytorch.
        pixels = (face * 128.0 + 127.5).clamp(0, 255).byte()
        Image.fromarray(pixels.permute(1, 2, 0).cpu().numpy()).show()

In [11]:
def read_txt_file_to_dict(file_path):
    """Group photo names by identity index read from a text file.

    Every non-blank line must hold exactly two whitespace-separated fields:
    the photo name followed by its index. Blank lines (for example a trailing
    empty line) are ignored.

    Args:
        file_path: Path of the text file to read.

    Returns:
        dict mapping each index (kept as a string) to the list of photo names
        that carry it, in file order.

    Raises:
        ValueError: If a non-blank line does not contain exactly two fields.
    """
    result_dict = {}

    with open(file_path, 'r') as file:
        for line in file:
            fields = line.split()
            if not fields:
                # Skip blank lines instead of failing on the unpacking below.
                continue
            photos_name, index = fields
            result_dict.setdefault(index, []).append(photos_name)

    return result_dict

In [None]:
def make_train_and_test_folders(index_to_photos_name_dict, source_folder, train_folder='train', test_folder='test', test_size=0.2):
    """Split identities into train/test and copy their photos into per-identity folders.

    The split is made over the dictionary keys (identities), so all photos of
    one identity land in the same subset. Output folders are created under the
    current working directory as ``<subset>/<index>/<photo_name>``.

    Args:
        index_to_photos_name_dict: Mapping of identity index to its photo names.
        source_folder: Directory that holds the original photos.
        train_folder: Name of the training output folder.
        test_folder: Name of the test output folder.
        test_size: Value forwarded to ``train_test_split`` as ``test_size``.
    """
    working_dir = os.getcwd()
    train_root = os.path.join(working_dir, train_folder)
    test_root = os.path.join(working_dir, test_folder)
    for subset_root in (train_root, test_root):
        os.makedirs(subset_root, exist_ok=True)

    identities = list(index_to_photos_name_dict.keys())
    train_ids, test_ids = train_test_split(identities, test_size=test_size, shuffle=True)

    # Train identities are copied first, then test identities.
    for subset_ids, subset_root in ((train_ids, train_root), (test_ids, test_root)):
        for identity in subset_ids:
            identity_photos = index_to_photos_name_dict[identity]
            identity_dir = os.path.join(subset_root, identity)
            os.makedirs(identity_dir, exist_ok=True)
            for photo_name in identity_photos:
                shutil.copy(
                    os.path.join(source_folder, photo_name),
                    os.path.join(identity_dir, photo_name),
                )

In [12]:
# The identity annotation file is looked up in the current working directory.
txt_file_path = os.path.join(os.getcwd(), 'identity_CelebA.txt')

# Maps each identity index (string) to the list of photo file names carrying it.
index_to_photos_name_dict = read_txt_file_to_dict(txt_file_path)

In [13]:
# Display the parsed mapping as the cell output.
index_to_photos_name_dict

{'2880': ['000001.jpg',
  '000404.jpg',
  '003415.jpg',
  '004390.jpg',
  '018062.jpg',
  '025244.jpg',
  '027771.jpg',
  '039393.jpg',
  '047978.jpg',
  '049142.jpg',
  '052385.jpg',
  '052623.jpg',
  '053184.jpg',
  '053311.jpg',
  '055834.jpg',
  '058188.jpg',
  '061431.jpg',
  '068154.jpg',
  '084705.jpg',
  '090937.jpg',
  '096324.jpg',
  '100990.jpg',
  '103728.jpg',
  '108341.jpg',
  '110376.jpg',
  '122439.jpg',
  '131731.jpg',
  '134007.jpg',
  '139106.jpg',
  '140935.jpg'],
 '2937': ['000002.jpg',
  '011437.jpg',
  '016335.jpg',
  '017121.jpg',
  '024291.jpg',
  '037082.jpg',
  '045318.jpg',
  '046844.jpg',
  '048360.jpg',
  '055891.jpg',
  '057357.jpg',
  '058209.jpg',
  '058400.jpg',
  '059638.jpg',
  '060925.jpg',
  '063242.jpg',
  '063616.jpg',
  '066809.jpg',
  '077346.jpg',
  '095867.jpg',
  '099877.jpg',
  '105287.jpg',
  '108661.jpg',
  '114336.jpg',
  '114625.jpg',
  '117710.jpg',
  '120280.jpg',
  '125140.jpg',
  '142601.jpg',
  '152380.jpg'],
 '8692': ['000003.jpg'

In [None]:
def _load_episode_image(fname, image_size):
    """Read one image file, convert it to RGB and resize it to ``image_size``."""
    image = cv2.imread(fname)
    if image is None:
        # cv2.imread signals an unreadable or missing file by returning None.
        raise FileNotFoundError(f'Could not read image: {fname}')
    # OpenCV decodes to BGR channel order; downstream display expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return cv2.resize(image, image_size)


def extract_sample(n_way, n_support, n_query, datax, datay, image_size=(28, 28)):
    """Draw one random few-shot episode from a labelled set of image paths.

    ``n_way`` distinct classes are drawn without replacement, and for each of
    them ``n_support + n_query`` image paths are picked at random, loaded and
    resized.

    Args:
        n_way: Number of classes in the episode.
        n_support: Number of support images per class.
        n_query: Number of query images per class.
        datax: Array-like of image file paths.
        datay: Array-like of class labels, aligned with ``datax``.
        image_size: ``(width, height)`` passed to ``cv2.resize``. Defaults to
            the previously hard-coded ``(28, 28)``.

    Returns:
        dict with keys ``'images'`` (float tensor of shape
        ``(n_way, n_support + n_query, 3, height, width)``, RGB, values as
        decoded by OpenCV), ``'targets'`` (the drawn class labels),
        ``'n_way'``, ``'n_support'`` and ``'n_query'``.

    Raises:
        ValueError: If a drawn class has fewer than ``n_support + n_query``
            images (or if ``n_way`` exceeds the number of classes).
        FileNotFoundError: If an image file cannot be read.
    """
    datax = np.asarray(datax)
    datay = np.asarray(datay)
    n_per_class = n_support + n_query

    K = np.random.choice(np.unique(datay), n_way, replace=False)

    # Check every drawn class before any file is read, so that a short class
    # fails with a clear message rather than producing a ragged array.
    paths_per_class = []
    for cls in K:
        cls_paths = datax[datay == cls]
        if len(cls_paths) < n_per_class:
            raise ValueError(
                f'Class {cls!r} has {len(cls_paths)} images, '
                f'but {n_per_class} are required per class'
            )
        paths_per_class.append(cls_paths)

    sample = []
    for cls_paths in paths_per_class:
        chosen = np.random.permutation(cls_paths)[:n_per_class]
        sample.append([_load_episode_image(fname, image_size) for fname in chosen])

    sample = np.array(sample)
    sample = torch.from_numpy(sample).float()
    # (way, shot, H, W, C) -> (way, shot, C, H, W)
    sample = sample.permute(0, 1, 4, 2, 3)
    return {
        'images': sample,
        'targets': K,
        'n_way': n_way,
        'n_support': n_support,
        'n_query': n_query,
    }

In [None]:
def display_sample(sample):
    """Plot all images of an episode as a grid with one class per row.

    Args:
        sample: Image tensor whose first two dimensions are (class, image) and
            whose remaining dimensions are (channels, height, width).
    """
    # Merge the class and image dimensions; reshape also accepts
    # non-contiguous input such as a permuted tensor.
    sample_4D = sample.reshape(sample.shape[0] * sample.shape[1], *sample.shape[2:])
    out = torchvision.utils.make_grid(sample_4D, nrow=sample.shape[1])

    image = out.permute(1, 2, 0).detach().cpu()
    # matplotlib clips float RGB data to [0, 1]; pixel values on a 0..255
    # scale must be shown as uint8 or the picture comes out washed out.
    if image.is_floating_point() and image.numel() > 0 and image.max() > 1:
        image = image.clamp(0, 255).to(torch.uint8)

    plt.figure(figsize=(16, 7))
    plt.imshow(image.numpy())

In [None]:
# Draw one 8-way episode with 5 support and 5 query images per class and show it.
# NOTE(review): trainx / trainy are not defined in the visible part of the
# notebook — confirm where they are built.
sample_example = extract_sample(8, 5, 5, trainx, trainy)
display_sample(sample_example['images'])
print(sample_example['targets'])

In [None]:
# Show the shape of the episode image tensor as the cell output.
sample_example['images'].shape

## Build model

In [None]:
class ProtoNet(nn.Module):
    """Prototypical network that embeds images with InceptionResnetV1.

    The encoder is created with ``pretrained='vggface2'`` and moved to
    ``device``. The class exposes the episode loss through
    ``set_forward_loss``.
    """

    def __init__(self, device='cuda'):
        """Build the encoder on ``device`` (``'cuda'`` by default)."""
        super().__init__()
        self.device = device
        self.encoder = InceptionResnetV1(pretrained='vggface2').to(self.device)

    def set_forward_loss(self, sample):
        """Compute the prototypical loss and accuracy for one episode.

        Args:
            sample: dict with ``'images'`` (tensor whose leading dimensions are
                ``(n_way, n_support + n_query)`` followed by three image
                dimensions), ``'n_way'``, ``'n_support'`` and ``'n_query'``.
                Within each class the first ``n_support`` images are the
                support set and the rest are queries.

        Returns:
            Tuple ``(loss, info)`` where ``loss`` is the scalar loss tensor
            (usable for ``backward``) and ``info`` is a dict with ``'loss'``
            (Python float), ``'acc'`` (fraction of correctly classified
            queries) and ``'y_hat'`` (``(n_way, n_query)`` tensor of predicted
            class positions).
        """
        n_way = sample['n_way']
        n_support = sample['n_support']
        n_query = sample['n_query']
        n_per_class = n_support + n_query

        sample_images = sample['images'].to(self.device)
        # Flatten (way, shot) into one batch dimension; reshape also works for
        # inputs that are not contiguous.
        sample_images = sample_images.reshape(n_way * n_per_class, *sample_images.shape[-3:])

        # Image embeddings, regrouped per class.
        vectors = self.encoder(sample_images).reshape(n_way, n_per_class, -1)

        # Class prototypes: mean embedding of each class's support images.
        prototypes = vectors[:, :n_support].mean(1)

        # Euclidean distance from every query embedding to every prototype,
        # computed in one call instead of a Python double loop. The explicit
        # compute mode keeps the direct (non matrix-multiplication) formula.
        queries = vectors[:, n_support:].reshape(-1, vectors.shape[-1])
        query_dists = torch.cdist(
            queries, prototypes, p=2.0,
            compute_mode='donot_use_mm_for_euclid_dist',
        )

        # Log-probabilities over classes for each query.
        probabilities = F.log_softmax(-query_dists, dim=1).view(n_way, n_query, -1)

        _, y_hat = probabilities.max(2)

        # The true class of every query in row ``way`` is ``way`` itself.
        labels = torch.arange(n_way, device=probabilities.device)
        labels = labels.view(n_way, 1).expand(n_way, n_query)

        loss_val = -probabilities.gather(2, labels.unsqueeze(2)).mean()
        n_true_positive = int((y_hat == labels).sum().item())
        acc = n_true_positive / (n_way * n_query)

        return loss_val, {
            'loss': loss_val.item(),
            'acc': acc,
            'y_hat': y_hat
            }

## Train

In [None]:
def train(model, optimizer, train_x, train_y, n_way, n_support, n_query, max_epoch, epoch_size):
    """Train ``model`` on randomly drawn episodes and save its weights.

    Each epoch draws ``epoch_size`` episodes with ``extract_sample`` and does
    one optimizer step per episode. The learning rate is halved after every
    epoch (``StepLR`` with step size 1 and ``gamma=0.5``). When a wandb run is
    active, the model is watched and per-epoch metrics are logged. After the
    last epoch the state dict is written to ``new_model.pt``.

    Args:
        model: Object providing ``set_forward_loss(sample)`` and ``state_dict()``.
        optimizer: Optimizer over the model parameters.
        train_x: Training inputs forwarded to ``extract_sample``.
        train_y: Training labels forwarded to ``extract_sample``.
        n_way: Number of classes per episode.
        n_support: Number of support images per class.
        n_query: Number of query images per class.
        max_epoch: Number of epochs to run.
        epoch_size: Number of episodes per epoch.

    Returns:
        Tuple ``(train_losses, train_acc)`` with one mean value per epoch.
    """
    if wandb.run is not None:
        wandb.watch(model)
    lr_schedule = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.5, last_epoch=-1)

    loss_history = []
    acc_history = []
    epoch_idx = 0
    while epoch_idx < max_epoch:
        epoch_number = epoch_idx + 1
        loss_sum = 0.0
        acc_sum = 0.0

        progress = tnrange(epoch_size, desc="Epoch {:d} train".format(epoch_number))
        for _ in progress:
            episode = extract_sample(n_way, n_support, n_query, train_x, train_y)
            optimizer.zero_grad()
            loss, metrics = model.set_forward_loss(episode)
            loss_sum += metrics['loss']
            acc_sum += metrics['acc']
            loss.backward()
            optimizer.step()

        mean_loss = loss_sum / epoch_size
        mean_acc = acc_sum / epoch_size
        loss_history.append(mean_loss)
        acc_history.append(mean_acc)

        if wandb.run is not None:
            wandb.log({'Epoch': epoch_number, 'Loss': round(mean_loss, 4), 'Acc': round(mean_acc, 4)})
        print('Epoch {:d} -- Loss: {:.4f} Acc: {:.4f}'.format(epoch_number, mean_loss, mean_acc))

        epoch_idx += 1
        lr_schedule.step()

    torch.save(model.state_dict(), "new_model.pt")
    return loss_history, acc_history

In [None]:
# When True, the training cell starts a wandb run before calling train().
should_log = False

# NOTE(review): load_protonet_conv is not defined in the visible part of this
# notebook, while the ProtoNet class defined earlier only takes `device` —
# confirm which constructor is intended here.
model = load_protonet_conv(
    x_dim=(3, 28, 28),
    hid_dim=64,
    z_dim=64,
    device='cuda'
)

optimizer = optim.Adam(model.parameters(), lr=0.001)

# Episode layout used for training: classes per episode, support and query
# images per class.
n_way = 60
n_support = 5
n_query = 5

# NOTE(review): trainx / trainy are not defined in the visible part of the
# notebook — confirm where they are built.
train_x = trainx
train_y = trainy

# Number of epochs and number of episodes per epoch.
max_epoch = 3
epoch_size = 1000

In [None]:
from show_train_res import plot_losses

In [None]:
%%time
# Optionally start a wandb run, then train; `loss` and `acc` hold the
# per-epoch lists returned by train().
if should_log:
    wandb.init(project='homework', name='baseline')
loss, acc = train(model, optimizer, train_x, train_y, n_way, n_support, n_query, max_epoch, epoch_size)

## Графики обучения ProtoNet

In [None]:
# Plot the per-epoch training loss and accuracy; the second argument is an
# empty 2-row array.
# NOTE(review): presumably the empty array stands for validation curves —
# confirm against show_train_res.plot_losses.
plot_losses(np.array([loss, acc]), np.array([[],[]]), 'loss', 'acc')

## Test

In [None]:
def test(model, test_x, test_y, n_way, n_support, n_query, test_episode):
    """Evaluate ``model`` on randomly drawn episodes.

    The model is switched to evaluation mode and gradients are disabled for
    the whole evaluation, so normalization/dropout layers behave as at
    inference time and no autograd graph is built. The previous training mode
    is restored before returning. When a wandb run is active, the model is
    watched and the averaged metrics are logged.

    Args:
        model: ``nn.Module`` providing ``set_forward_loss(sample)``.
        test_x: Test inputs forwarded to ``extract_sample``.
        test_y: Test labels forwarded to ``extract_sample``.
        n_way: Number of classes per episode.
        n_support: Number of support images per class (logged as ``k_shot``).
        n_query: Number of query images per class.
        test_episode: Number of episodes to average over.

    Returns:
        List ``[avg_loss, avg_acc]``.
    """
    if wandb.run is not None:
        wandb.watch(model)
    running_loss = 0.0
    running_acc = 0.0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in tnrange(test_episode):
                sample = extract_sample(n_way, n_support, n_query, test_x, test_y)
                _, output = model.set_forward_loss(sample)
                running_loss += output['loss']
                running_acc += output['acc']
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    avg_loss = running_loss / test_episode
    avg_acc = running_acc / test_episode
    if wandb.run is not None:
        wandb.log({'n_way': n_way, 'k_shot': n_support})
        wandb.log({'Test loss': round(avg_loss, 4), 'Test acc': round(avg_acc, 4), 'n_way': n_way, 'k_shot': n_support})
    return [avg_loss, avg_acc]

In [None]:
# Episode layout used for evaluation.
n_way = 5
n_support = 5
n_query = 5

# NOTE(review): testx / testy are not defined in the visible part of the
# notebook — confirm where they are built.
test_x = testx
test_y = testy

# Number of episodes to average over.
test_episode = 1000

# Rebuild the model and load the weights written by train().
# NOTE(review): load_protonet_conv is not defined in the visible part of this
# notebook — confirm it matches the architecture whose weights were saved.
model = load_protonet_conv(
    x_dim=(3, 28, 28),
    hid_dim=64,
    z_dim=64,
    device='cuda'
)
model.load_state_dict(torch.load('./new_model.pt'))

In [None]:
# Evaluate with 5 classes and 1 support image per class; n_query and
# test_episode come from the configuration cell.
print('test: 5-way 1-shot')
output = test(model, test_x, test_y, 5, 1, n_query, test_episode)
print(f'loss: {output[0]}, acc: {output[1]}')
# Close the wandb run if one is active.
if wandb.run is not None:
    wandb.finish()

## Prediction on our images

In [None]:
def predict(model, sample):
    """Print the predicted target next to the true target for each class.

    The model is run in evaluation mode under ``torch.inference_mode()`` and
    its previous training mode is restored afterwards. For every class of the
    episode, the prediction printed is the one for that class's first query
    image.

    Args:
        model: ``nn.Module`` providing ``set_forward_loss(sample)`` whose
            second return value contains ``'y_hat'``.
        sample: Episode dict; ``sample['targets']`` must be indexable with an
            integer array (e.g. a NumPy array of class labels).
    """
    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            y_hat = model.set_forward_loss(sample)[1]['y_hat']
            predicted_positions = y_hat.to('cpu').detach().numpy()
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    # Map predicted class positions back to the episode's target labels.
    preds = sample['targets'][predicted_positions]
    print('predict : target')
    for pred, target in zip(preds, sample['targets']):
        print(f'{pred[0]} : {target}')