In [23]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchsummary import summary


import my_utility as mu

In [24]:
# Parametri della rete
epochs = 10
batch_size = 16
learning_rate = 0.01
margin = 1  # Margin for contrastive loss.

In [25]:
# Selezione del device da usare per il training
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

In [26]:
n = 384 # Dimensione codifica frattale


In [27]:
root_dir = ".\img_celeba_10000"

all_img_dir = root_dir + "\\all_img_celeba"
all_codify_dir = root_dir + "\\all_codify_celeba\codify_celeba_all.csv"


all_set = mu.ImageDataSet(img_dir = all_img_dir, codify_dir = all_codify_dir)



all_loader = DataLoader(all_set, batch_size=batch_size, shuffle=True)

Data loaded.


In [28]:
model = mu.SiameseNeuralNetwork().to(device)
summary(model, input_size=(1, 128, 128))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
       BatchNorm2d-1          [-1, 1, 128, 128]               2
            Conv2d-2          [-1, 4, 124, 124]             104
         AvgPool2d-3            [-1, 4, 62, 62]               0
            Conv2d-4           [-1, 16, 58, 58]           1,616
         AvgPool2d-5           [-1, 16, 29, 29]               0
       BatchNorm1d-6                [-1, 13456]          26,912
            Linear-7                  [-1, 384]       5,167,488
Total params: 5,196,122
Trainable params: 5,196,122
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.06
Forward/backward pass size (MB): 1.33
Params size (MB): 19.82
Estimated Total Size (MB): 21.21
----------------------------------------------------------------


In [42]:
def train(model, optimizer, loss_fn, train_loader, epochs, device):
    for epoch in range(epochs):
        training_loss = 0.0
        
        # Training loop
        model.train()
        for batch in train_loader:
            inputs, label = batch
            inputs = inputs.to(device)
            label = label.to(device)

            optimizer.zero_grad()

            # Calcolo dell'embedding di output
            output = model(inputs)
            
            # Confronto con l'embedding di input
            loss = loss_fn(output, label)
            #loss.requires_grad = True
            
            
            loss.backward()
            optimizer.step()
            
            training_loss += loss.item()

        avg_training_loss = training_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{epochs}] ----> Training loss: {avg_training_loss:.4f}")
        print(f"Epoch [{epoch+1}/{epochs}] ----> Training loss: {training_loss:.4f}")
    


In [43]:
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Lr scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10)

In [44]:
'''distance = mu.custom_canberra_distance
criterion = mu.CustomLoss(distance=distance)'''

criterion = mu.CustomLoss2(distance=mu.custom_canberra_distance_batch)

In [45]:
train(model=model, optimizer=optimizer, loss_fn=criterion, train_loader=all_loader, epochs=epochs, device=device)


Epoch [1/10] ----> Training loss: 0.7103
Epoch [1/10] ----> Training loss: 532.7161
Epoch [2/10] ----> Training loss: 0.6911
Epoch [2/10] ----> Training loss: 518.3403
Epoch [3/10] ----> Training loss: 0.6906
Epoch [3/10] ----> Training loss: 517.9128
Epoch [4/10] ----> Training loss: 0.6905
Epoch [4/10] ----> Training loss: 517.8427
Epoch [5/10] ----> Training loss: 0.6904
Epoch [5/10] ----> Training loss: 517.7838
Epoch [6/10] ----> Training loss: 0.6902
Epoch [6/10] ----> Training loss: 517.6509
Epoch [7/10] ----> Training loss: 0.6902
Epoch [7/10] ----> Training loss: 517.6520
Epoch [8/10] ----> Training loss: 0.6903
Epoch [8/10] ----> Training loss: 517.7546


KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), f="model_all_images")