In [25]:
import torchvision
import torch
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import math
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import datasets
import pytorch_lightning as pl
import glob
from pandas.core.common import flatten

import matplotlib.pyplot as plt
from PIL import Image

In [26]:
# Location of the training images.
train_dataset_path = '../input/iais22-birds/birds/birds'

# Per-channel statistics for normalization: image = (image - mean) / std
mean = [0.4704, 0.4669, 0.3898]
std = [0.2035, 0.2001, 0.2047]
# Presumably the number of images the statistics were computed over -- confirm.
number_of_images = 58388

# Preprocessing pipeline: resize to 64x64, apply light random augmentation
# (horizontal flip, rotation of up to 10 degrees) so the data is less biased,
# then convert to a tensor and normalize.
train_transforms = transforms.Compose([
    transforms.Resize(size=(64, 64)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=torch.Tensor(mean), std=torch.Tensor(std)),
])

In [27]:
# Full labelled dataset read from the folder tree; every sample is passed
# through train_transforms when it is fetched.
orig_set = datasets.ImageFolder(
    root=train_dataset_path,
    transform=train_transforms,
)

In [28]:
# 80/20 random split of the full dataset into training and validation subsets.
n = len(orig_set)  # total number of examples
train_set_size = int(0.8 * n)
valid_set_size = n - train_set_size  # remainder, so both sizes always sum to n
train_set, valid_set = data.random_split(orig_set, [train_set_size, valid_set_size])

# Train on the training subset only. Building this loader from orig_set would
# feed the validation images to the optimizer and inflate val_acc.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, pin_memory=True)
# NOTE(review): valid_set is a subset of orig_set, so it shares whatever
# transform orig_set was built with.
test_loader = torch.utils.data.DataLoader(valid_set, batch_size=32, shuffle=False, pin_memory=True)

In [29]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [30]:
class PatchEmbedding(nn.Module):
    """Cut an image into square patches and project each patch to an embedding.

    Args:
        img_size: side length of the (square) input image; only used to
            derive ``n_patches``.
        patch_size: side length of a patch; kernel size and stride of the
            projection.
        in_chans: number of input channels.
        embed_dim: size of the embedding produced for each patch.
    """

    def __init__(self, img_size, patch_size, in_chans, embed_dim):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2
        # A convolution whose kernel and stride both equal the patch size
        # applies the same linear map to every non-overlapping patch.
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return patch embeddings of shape (B, N, E) for a (B, C, H, W) batch."""
        projected = self.proj(x)  # (B, E, H / patch, W / patch)
        # Collapse the spatial grid into one token axis, then put tokens first:
        # (B, E, N) -> (B, N, E)
        return projected.flatten(2).transpose(1, 2)

In [31]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over the token axis.

    Every head works at the full embedding width ``n_embd`` (the projections
    map ``n_embd -> n_embd * n_heads``) instead of splitting it, which suits
    a small number of heads.

    Args:
        n_embd: embedding size of each input and output token.
        n_heads: number of attention heads.
    """

    def __init__(self, n_embd, n_heads):
        super().__init__()
        self.n_heads = n_heads

        # Key, query and value projections for all heads at once.
        self.key = nn.Linear(n_embd, n_embd * n_heads)
        self.query = nn.Linear(n_embd, n_embd * n_heads)
        self.value = nn.Linear(n_embd, n_embd * n_heads)

        # Merges the concatenated head outputs back to n_embd.
        self.proj = nn.Linear(n_embd * n_heads, n_embd)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, L, E); returns (B, L, E)."""
        # The embedding size is called E (not F) so the notebook's alias for
        # torch.nn.functional is not shadowed inside this method.
        B, L, E = x.size()
        H = self.n_heads

        # Split each projection into heads and move the head axis next to the
        # batch axis: (B, L, H*E) -> (B, L, H, E) -> (B, H, L, E).
        # The token axis must be second to last so that q @ k^T is an (L, L)
        # token-to-token matrix rather than a feature-to-feature one.
        k = self.key(x).view(B, L, H, E).transpose(1, 2)
        q = self.query(x).view(B, L, H, E).transpose(1, 2)
        v = self.value(x).view(B, L, H, E).transpose(1, 2)

        # Attention weights: scaled dot products, softmax so each row sums to 1.
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # (B, H, L, L)
        att = torch.nn.functional.softmax(att, dim=-1)
        y = att @ v  # (B, H, L, E)

        # Put tokens first again and concatenate the heads: (B, L, H*E).
        y = y.transpose(1, 2).contiguous().view(B, L, H * E)

        return self.proj(y)

In [32]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and an MLP, each with a residual.

    Args:
        n_embd: embedding size of each token.
        n_heads: number of attention heads passed to MultiHeadAttention.
    """

    def __init__(self, n_embd, n_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)  # normalizes the attention input
        self.ln2 = nn.LayerNorm(n_embd)  # normalizes the MLP input
        self.attn = MultiHeadAttention(n_embd, n_heads)
        hidden_dim = 4 * n_embd
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, hidden_dim),  # expand
            nn.ReLU(),                      # activation
            nn.Linear(hidden_dim, n_embd),  # project back
        )

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        # Normalize first, then attend, then add the residual.
        attended = self.attn(self.ln1(x))
        x = x + attended
        fed_forward = self.mlp(self.ln2(x))
        return x + fed_forward

In [33]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    The image is split into patches, a learnable class token is prepended,
    learnable positional embeddings are added, the sequence goes through a
    stack of transformer blocks, and the normalized class token is mapped to
    class logits.

    Args:
        img_size: side length of the (square) input images.
        patch_size: side length of each patch.
        in_chans: number of input channels.
        embed_dim: token embedding size.
        n_heads: attention heads per block.
        n_layers: number of transformer blocks.
        n_classes: number of output classes.
    """

    def __init__(self, img_size=64, patch_size=8, in_chans=3, embed_dim=512, n_heads=4, n_layers=6, n_classes=400):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
        # Class token: extra first element of the sequence whose final state
        # is fed to the classifier.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Positional embedding: one vector per patch plus one for the class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim))

        # The block is repeated once per layer. The attribute name keeps its
        # original spelling because it is part of the state_dict keys.
        self.tranformer = torch.nn.Sequential(*[TransformerBlock(embed_dim, n_heads) for _ in range(n_layers)])

        self.ln = nn.LayerNorm(embed_dim)
        self.fc = torch.nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        """Return class logits of shape (B, n_classes) for an image batch ``x``."""
        e = self.patch_embed(x)  # (B, N, E)
        batch_size = e.size(0)

        cls_token = self.cls_token.expand(batch_size, -1, -1)
        e = torch.cat((cls_token, e), dim=1)  # prepend the class token: (B, 1 + N, E)
        e = e + self.pos_embed

        z = self.tranformer(e)

        # Only the class token is used for prediction. The blocks are pre-norm,
        # so their output is un-normalized: apply the final LayerNorm (which
        # was previously constructed but never used) before the linear head.
        cls_token_final = self.ln(z[:, 0])
        y = self.fc(cls_token_final)

        return y
    

In [34]:
# Stand-alone ViT instance built with the default hyper-parameters.
vit = ViT()
# Disabled shape check (`images` is not defined in the visible cells):
#out = vit(images)
#out.shape


In [35]:
class Model(pl.LightningModule):
    """Lightning wrapper around ViT trained with cross-entropy and Adam."""

    def __init__(self):
        super().__init__()
        self.vit = ViT()

    def forward(self, x):
        """Return the logits produced by the wrapped ViT."""
        return self.vit(x)

    def predict(self, x):
        """Return the highest-scoring class index per sample, without gradients."""
        with torch.no_grad():
            logits = self(x)
            return logits.argmax(dim=1)

    def compute_loss_and_acc(self, batch):
        """Return (cross-entropy loss tensor, accuracy as a Python float) for a batch."""
        inputs, targets = batch
        logits = self(inputs)
        loss = F.cross_entropy(logits, targets)
        n_correct = (logits.argmax(dim=1) == targets).sum().item()
        return loss, n_correct / targets.shape[0]

    def training_step(self, batch, batch_idx):
        """Log the training loss and accuracy and return the loss to optimize."""
        loss, acc = self.compute_loss_and_acc(batch)
        self.log('loss', loss)
        self.log('acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accuracy on the progress bar."""
        loss, acc = self.compute_loss_and_acc(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate 1e-4."""
        return torch.optim.Adam(self.parameters(), lr=0.0001)

In [36]:
# LightningModule instance that is trained and used for inference below.
model = Model()
# Disabled shape check (`images` is not defined in the visible cells):
#out = model(images)
#out.shape

In [46]:

# Two epochs on a single GPU, with no logger configured.
trainer = pl.Trainer(
    max_epochs=2,
    gpus=1,
    logger=None,
)
# test_loader is used as the validation loader during fitting.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)


In [38]:
# Run the validation loop over train_loader to report val_loss / val_acc on
# the data used for fitting.
trainer.validate(model, dataloaders=train_loader)

In [39]:
# Run the validation loop over test_loader to report val_loss / val_acc on
# the held-out split.
trainer.validate(model, dataloaders=test_loader)

In [40]:
test_path = '../input/iais22-birds/submission_test/submission_test'
# Every entry directly under the submission folder is one image to classify.
# A single glob is enough: re-globbing each result would treat file names as
# patterns and silently drop names containing glob metacharacters. Sorting
# makes the row order of the submission reproducible.
test_images_paths = sorted(glob.glob(test_path + '/*'))


In [41]:
def to_device(data, device):
    """Move a tensor, or every tensor in a (possibly nested) list/tuple, to ``device``.

    Lists and tuples are both returned as lists; anything else must provide
    a ``.to(device, non_blocking=...)`` method.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]

In [42]:
# Use the shared `device` rather than a hard-coded .cuda(): the inputs are
# moved to `device` as well, and this also works on a CPU-only machine.
# eval() puts the module in inference mode before predictions are made.
model = model.to(device).eval()

In [43]:
# Deterministic preprocessing for inference: the same resize and normalization
# as training but without the random flip / rotation, so a given image always
# produces the same prediction.
eval_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(torch.Tensor(mean), torch.Tensor(std)),
])

Categories = []
Id = []

for img_name in test_images_paths:
    # The context manager closes the file handle; convert('RGB') guarantees
    # the three channels the normalization expects (grayscale / RGBA files).
    with Image.open(img_name) as pil_img:
        img = eval_transforms(pil_img.convert('RGB'))
    img = to_device(img.unsqueeze(0), device)  # add the batch dimension

    # The Id is the file name up to its first dot.
    Id.append(os.path.basename(img_name).split('.')[0])

    # Map the predicted class index back to its folder (class) name.
    class_predict = model.predict(img)
    Categories.append(orig_set.classes[class_predict[0].item()])


In [44]:
import pandas as pd

# Assemble the submission table (one row per test image) and write it as CSV.
# The mapping is named `submission` rather than `dict` so the built-in type is
# not shadowed for the rest of the notebook session.
submission = {"Id": Id, "Category": Categories}

df = pd.DataFrame(submission)

df.to_csv('submit2.csv', index=False)

In [45]:
# Persist the trained weights.
model_path = 'transformer.pth'
torch.save(model.state_dict(), model_path)

# Reload them into a fresh instance as a sanity check. map_location='cpu'
# matches the freshly built (CPU) model and keeps the load working when the
# weights were saved from a GPU but no GPU is available at load time.
model2 = Model()
model2.load_state_dict(torch.load(model_path, map_location='cpu'))

model2.eval()