## Clasificación de los dígitos de Mnist en forma de set de puntos

In [1]:
!pwd

/Users/ahmedbegga/Desktop/TFG-Ahmed/SetXAI/src


In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import chamfer_loss
from fspool import FSPool
from model import FSEncoder
from MnistSet import MNISTSet
from MnistSet import get_loader

### Preparamos los datos de entreno y test

In [14]:
batch_size = 32

In [3]:
train_loader = get_loader(
            MNISTSet(train=True, full=True), batch_size=batch_size, num_workers=4)

In [4]:
test_loader = get_loader(
            MNISTSet(train=False, full=True), batch_size=batch_size, num_workers=4)

In [13]:
from time import sleep
from tqdm import tqdm

In [24]:
set_channels = 2
set_size = 342
  
hidden_dim = 256
iters = 10
latent_dim = 64
lr = 0.01
n_epochs = 10
net = FSEncoder(set_channels,latent_dim,hidden_dim)

In [25]:
optimizer = torch.optim.Adam(
        [p for p in net.parameters() if p.requires_grad], lr=lr
    )

In [26]:
net

FSEncoder(
  (conv): Sequential(
    (0): Conv1d(3, 256, kernel_size=(1,), stride=(1,))
    (1): ReLU()
    (2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    (3): ReLU()
    (4): Conv1d(256, 64, kernel_size=(1,), stride=(1,))
  )
  (pool): FSPool()
  (salida): Linear(in_features=64, out_features=10, bias=True)
)

In [27]:
net.train()
for epoch in range(n_epochs):
    with tqdm(train_loader, unit="batch") as tepoch:
        for i, sample in enumerate(tepoch):
            tepoch.set_description(f"Epoch {epoch}")
            input, target_set, target_mask = map(lambda x: x, sample)
            optimizer.zero_grad()
            output= net(target_set,target_mask)
            loss = F.nll_loss(output, input)
            pred = output.data.max(1, keepdim=True)[1]
            correct = pred.eq(input.data.view_as(pred)).sum()
            accuracy =correct / batch_size
            loss.backward()
            optimizer.step()
            tepoch.set_postfix(loss=loss.item(), acc=100. * accuracy.item())

Epoch 0: 100%|█████| 1875/1875 [09:48<00:00,  3.19batch/s, acc=90.6, loss=0.227]
Epoch 1: 100%|█████| 1875/1875 [09:49<00:00,  3.18batch/s, acc=100, loss=0.0648]
Epoch 2: 100%|████| 1875/1875 [2:02:44<00:00,  3.93s/batch, acc=90.6, loss=0.22]
Epoch 3: 100%|█████| 1875/1875 [10:26<00:00,  2.99batch/s, acc=93.8, loss=0.169]
Epoch 4: 100%|█████| 1875/1875 [10:15<00:00,  3.04batch/s, acc=90.6, loss=0.152]
Epoch 5: 100%|█████| 1875/1875 [10:07<00:00,  3.09batch/s, acc=90.6, loss=0.154]
Epoch 6: 100%|█████| 1875/1875 [10:23<00:00,  3.01batch/s, acc=93.8, loss=0.284]
Epoch 7: 100%|█████| 1875/1875 [12:39<00:00,  2.47batch/s, acc=93.8, loss=0.187]
Epoch 8: 100%|█████| 1875/1875 [10:34<00:00,  2.95batch/s, acc=96.9, loss=0.192]
Epoch 9: 100%|█████| 1875/1875 [10:48<00:00,  2.89batch/s, acc=90.6, loss=0.166]


In [28]:
net.eval()
with tqdm(test_loader, unit="batch") as tepoch:
    for i, sample in enumerate(tepoch):
        tepoch.set_description(f"Epoch {epoch}")
        input, target_set, target_mask = map(lambda x: x, sample)
        optimizer.zero_grad()
        output= net(target_set,target_mask)
        loss = F.nll_loss(output, input)
        pred = output.data.max(1, keepdim=True)[1]
        correct = pred.eq(input.data.view_as(pred)).sum()
        accuracy =correct / batch_size
        #loss.backward()
        #optimizer.step()
        tepoch.set_postfix(loss=loss.item(), acc=100. * accuracy.item())

Epoch 9: 100%|███████| 312/312 [01:23<00:00,  3.73batch/s, acc=87.5, loss=0.325]
