In [2]:
import torch
from torch.utils.data import TensorDataset, Dataset, DataLoader 
import torch.nn as nn
from torchinfo import summary 
from tqdm import tqdm
import json

In [3]:
# Set device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class PenultimateOutputsDataset(Dataset):
    def __init__(self, folder_path, device='cpu'):
        with open(f'{folder_path}/penultimate_layer_outputs/data.json') as data:
            self.data_tensors = json.load(data)
        self.file_list = [f'{folder_path}/penultimate_layer_outputs/penultimate_layer_outputs_{index}.pt' for index in range(self.data_tensors['last_file_indice']+1)]
        self.labels = torch.load(f'{folder_path}/penultimate_layer_labels.pt', map_location=device)
        self.length = self.data_tensors['min_size_segmentation']*self.data_tensors['last_file_indice']+torch.load(f'{folder_path}/penultimate_layer_outputs/penultimate_layer_outputs_{self.data_tensors["last_file_indice"]}.pt', map_location=device).shape[0]
        self.device = device
        self.data_cache = {}
        
    def __len__(self):
        return self.length

    def __getitem__(self, index):
        if index < 0 or index >= self.length :
           raise ValueError(f'Index need to be between 0 and {self.length-1}')
        index_file = index//self.data_tensors['min_size_segmentation']
        index_in_file = index%self.data_tensors['min_size_segmentation']
        if index_file not in self.data_cache.keys() :
            file = self.file_list[index_file]
            tensors = torch.load(file, map_location=self.device)
            self.data_cache[index_file] = tensors
        tensor = self.data_cache[index_file][index_in_file]
        label = self.labels[index]
        return tensor, label


In [20]:
outputs_folder = "../data/saved_outputs"
models_folder = "../data/saved_models"

input_size = 2048
num_classes = 1000

batch_size = 256
num_epochs = 50

weight_decay_parameter=1e-4
learning_rate = 0.001
# momentum = 0.9
# lr_decay_step = 30
# decay_rate = 0.1


In [19]:
dataset = PenultimateOutputsDataset(outputs_folder)
train_loader = DataLoader(dataset, batch_size=batch_size, pin_memory=True)

In [16]:
class LogisticRegression(nn.Module): 
    def __init__(self, input_size, num_classes): 
        super(LogisticRegression, self).__init__() 
        self.linear = nn.Linear(input_size, num_classes) 
  
    def forward(self, x): 
        out = self.linear(x) 
        out = nn.functional.softmax(out, dim=1) 
        return out 

In [8]:
# On définit le model
model = LogisticRegression(input_size=input_size, num_classes=num_classes) 
model = model.to(DEVICE)
summary(model, input_size=(batch_size, input_size))

Layer (type:depth-idx)                   Output Shape              Param #
LogisticRegression                       [256, 1000]               --
├─Linear: 1-1                            [256, 1000]               2,049,000
Total params: 2,049,000
Trainable params: 2,049,000
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 524.54
Input size (MB): 2.10
Forward/backward pass size (MB): 2.05
Params size (MB): 8.20
Estimated Total Size (MB): 12.34

In [9]:
# Define the loss function and optimizer 
loss_function = nn.CrossEntropyLoss() 
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay_parameter)

In [10]:
# Train the model
correct_predictions=0
loss_history = []
accuracy_history = []
for epoch in range(num_epochs): 
    for inputs, labels in tqdm(train_loader, desc="Training"):
        
        # Move inputs and labels to the device 
        inputs = inputs.to(DEVICE) 
        labels = labels.to(DEVICE)
        
        # Forward pass 
        outputs = model(inputs) 
        loss = loss_function(outputs, labels)

        _, predicted = torch.max(outputs, 1)

        # Ajout des prédictions correctes au total
        correct_predictions += (predicted == labels).sum().item()
        
        # Backward and optimize 
        optimizer.zero_grad() 
        loss.backward() 
        optimizer.step()

    torch.save(model.state_dict(), f'{models_folder}/epoch-{epoch}.pt')
    accuracy = 100 * correct_predictions / len(dataset)
    print(f'Accuracy: {accuracy}%')
    correct_predictions=0
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')
    loss_history.append(loss.item())
    accuracy_history.append(accuracy)

Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:41<00:00, 120.98it/s]


Accuracy: 0.10607516428381311%
Epoch [1/90], Loss: 6.9076924324035645


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:43<00:00, 115.06it/s]


Accuracy: 8.0913729435741%
Epoch [2/90], Loss: 6.907591819763184


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:47<00:00, 104.46it/s]


Accuracy: 22.22606420552512%
Epoch [3/90], Loss: 6.907517433166504


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:45<00:00, 109.57it/s]


Accuracy: 15.434209591723796%
Epoch [4/90], Loss: 6.907593250274658


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:45<00:00, 109.31it/s]


Accuracy: 0.5156236462537671%
Epoch [5/90], Loss: 6.9074578285217285


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:38<00:00, 129.01it/s]


Accuracy: 1.7240531484185901%
Epoch [6/90], Loss: 6.907451152801514


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:37<00:00, 132.75it/s]


Accuracy: 1.3797576740581048%
Epoch [7/90], Loss: 6.907432556152344


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:34<00:00, 146.37it/s]


Accuracy: 0.9549886939017318%
Epoch [8/90], Loss: 6.907428741455078


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:37<00:00, 135.02it/s]


Accuracy: 1.0858069244680826%
Epoch [9/90], Loss: 6.907458305358887


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:34<00:00, 147.18it/s]


Accuracy: 1.173851652438753%
Epoch [10/90], Loss: 6.907465934753418


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:33<00:00, 149.12it/s]


Accuracy: 1.2522957584764516%
Epoch [11/90], Loss: 6.908184051513672


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:36<00:00, 138.42it/s]


Accuracy: 1.4746711396718772%
Epoch [12/90], Loss: 6.907748222351074


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:35<00:00, 141.08it/s]


Accuracy: 1.555691022325739%
Epoch [13/90], Loss: 6.9080023765563965


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:34<00:00, 146.93it/s]


Accuracy: 1.5890200106621541%
Epoch [14/90], Loss: 6.908168792724609


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:37<00:00, 134.93it/s]


Accuracy: 1.7939893862392646%
Epoch [15/90], Loss: 6.908092975616455


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:36<00:00, 138.04it/s]


Accuracy: 1.7977359704082294%
Epoch [16/90], Loss: 6.908050537109375


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:37<00:00, 134.16it/s]


Accuracy: 1.8373873195297725%
Epoch [17/90], Loss: 6.908144950866699


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:41<00:00, 120.07it/s]


Accuracy: 1.8910103054480798%
Epoch [18/90], Loss: 6.908085346221924


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:40<00:00, 123.91it/s]


Accuracy: 2.0841935516603223%
Epoch [19/90], Loss: 6.908231735229492


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:38<00:00, 129.23it/s]


Accuracy: 2.0895012125663555%
Epoch [20/90], Loss: 6.9080963134765625


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:38<00:00, 131.58it/s]


Accuracy: 2.219148635579905%
Epoch [21/90], Loss: 6.908143997192383


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:40<00:00, 124.98it/s]


Accuracy: 2.168491695462028%
Epoch [22/90], Loss: 6.908138751983643


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:36<00:00, 135.56it/s]


Accuracy: 2.2728496753350655%
Epoch [23/90], Loss: 6.908274173736572


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:38<00:00, 131.68it/s]


Accuracy: 2.3408345672344044%
Epoch [24/90], Loss: 6.90805196762085


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:41<00:00, 121.03it/s]


Accuracy: 2.4061656286807263%
Epoch [25/90], Loss: 6.908080101013184


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:37<00:00, 132.41it/s]


Accuracy: 2.4608813683149817%
Epoch [26/90], Loss: 6.908261299133301


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:40<00:00, 124.12it/s]


Accuracy: 2.506777024384799%
Epoch [27/90], Loss: 6.908004283905029


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:41<00:00, 119.48it/s]


Accuracy: 2.522856114776606%
Epoch [28/90], Loss: 6.908245086669922


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:37<00:00, 135.15it/s]


Accuracy: 2.5494724731436262%
Epoch [29/90], Loss: 6.9082417488098145


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:37<00:00, 132.35it/s]


Accuracy: 2.639390493198779%
Epoch [30/90], Loss: 6.908204555511475


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:39<00:00, 125.68it/s]


Accuracy: 2.639234385525072%
Epoch [31/90], Loss: 6.908001899719238


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:39<00:00, 128.22it/s]


Accuracy: 2.742031288661041%
Epoch [32/90], Loss: 6.907899856567383


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:38<00:00, 130.24it/s]


Accuracy: 2.8629366819470063%
Epoch [33/90], Loss: 6.907830715179443


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:40<00:00, 123.78it/s]


Accuracy: 2.949966710038582%
Epoch [34/90], Loss: 6.9077839851379395


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:37<00:00, 134.18it/s]


Accuracy: 3.016234417527145%
Epoch [35/90], Loss: 6.907751083374023


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:36<00:00, 138.10it/s]


Accuracy: 3.068530488218944%
Epoch [36/90], Loss: 6.907727241516113


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:40<00:00, 124.28it/s]


Accuracy: 3.1055280068874707%
Epoch [37/90], Loss: 6.907710552215576


Training: 100%|███████████████████████████████████████████████████████████████████| 5005/5005 [00:37<00:00, 133.69it/s]


Accuracy: 3.13144188072281%
Epoch [38/90], Loss: 6.907697677612305


Training:  40%|██████████████████████████▊                                        | 2000/5005 [00:16<00:24, 122.82it/s]


KeyboardInterrupt: 