In [1]:
import os
import glob
import numpy as np
import random
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torch.utils.data import TensorDataset, DataLoader, Dataset
import torch.nn.functional as F



from FileReader import get_picture_tensors
from ModelEvaluation import eval_model

from CatNet.datasets import dataset_cat_no_cat
from CatNet.models import CatNet


In [3]:
# History of the summed training loss, one entry appended per epoch by the
# training cell. It lives in this cell, so re-running the training cell keeps
# extending the same history.
loss_at_each_epoch = []

# Class count handed to get_picture_tensors by the training cell.
n_classes = 32

# Attention mechanism = a single output (0 is background, 1 is cat).
model = CatNet(num_classes=1, cnn_backbone='mobilenet_v2')

# Print the layer-by-layer summary table of the network.
model.summary()


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 112, 112]             864
       BatchNorm2d-2         [-1, 32, 112, 112]              64
             ReLU6-3         [-1, 32, 112, 112]               0
            Conv2d-4         [-1, 32, 112, 112]             288
       BatchNorm2d-5         [-1, 32, 112, 112]              64
             ReLU6-6         [-1, 32, 112, 112]               0
            Conv2d-7         [-1, 16, 112, 112]             512
       BatchNorm2d-8         [-1, 16, 112, 112]              32
  InvertedResidual-9         [-1, 16, 112, 112]               0
           Conv2d-10         [-1, 96, 112, 112]           1,536
      BatchNorm2d-11         [-1, 96, 112, 112]             192
            ReLU6-12         [-1, 96, 112, 112]               0
           Conv2d-13           [-1, 96, 56, 56]             864
      BatchNorm2d-14           [-1, 96,

In [None]:

# Train the cat / background attention classifier and plot the running loss.
#
# Reads `model`, `n_classes` and `loss_at_each_epoch` from the model-setup cell.
# Because `loss_at_each_epoch` is not reset here, re-running this cell continues
# training and keeps extending the same loss history.

from matplotlib.ticker import MultipleLocator  # hoisted: was re-imported on every epoch

# --- Data parameters ---------------------------------------------------------
# Earlier, non-augmented dataset locations (kept for reference):
# cat_directory = "normal_prep_datasets/dataset_chat_downscale_no_background/"
# background_directory = "normal_prep_datasets/dataset_chat_downscale_no_cat"

cat_directory = "dataset_augmented_no_background"
background_directory = "dataset_augmented_no_cat"

required_train_imgs = 10
required_test_imgs = 1

# --- Training hyper-parameters -----------------------------------------------
num_epochs = 32
batches_per_epoch = 32   # NOTE(review): not read anywhere in this cell (8 was also tried)
batch_size = 8           # 128 was also tried
learning_rate = 1e-4     # other values tried: 1e-2, 1e-3, 1e-5, 1e-6
ratio = 0.5              # NOTE(review): not read anywhere in this cell ([0.75, 0.25] was also tried)

# Keyword arguments shared by both get_picture_tensors calls. `n_classes` is
# deliberately NOT part of this dict: every call returns an updated value that
# is fed into the next call, exactly as before.
picture_loader_kwargs = dict(
    required_train_imgs=required_train_imgs,
    required_test_imgs=required_test_imgs,
    use_selected_eval_datasets=False,
    shuffle_directories=True,
    shuffle_images=True,
    show_progress=False,
    ordered_dataset=True,
)


def _plot_loss_history(losses):
    """Show the per-epoch training loss curve recorded so far.

    Args:
        losses: Non-empty sequence with one summed loss value per finished epoch.
    """
    maxtot = max(losses)

    fig, ax = plt.subplots(1, 1, figsize=(5, 2.5))
    ax.set_xlabel('Epoch (-)')
    ax.set_ylabel('Loss (-)')
    ax.set_ylim(0, 1.05*maxtot)  # 5 % headroom above the largest loss so far
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.plot(losses, '.-')
    ax.grid(True)

    plt.show()
    # One figure is created per epoch; close it explicitly so figures do not
    # pile up in memory over a long run.
    plt.close(fig)


model.train()

# BCEWithLogitsLoss includes the sigmoid, so the raw model output is passed to
# it (an explicit torch.sigmoid would only be needed with a plain BCELoss).
criterion = nn.BCEWithLogitsLoss()

# Optimise only the parameters that still require gradients. A list (rather
# than a one-shot filter iterator) stays reusable after the optimizer is built.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr = learning_rate)

# Alternative optimizer to experiment with:
# momentum = 0.5
# optimizer = optim.SGD(params, lr = learning_rate, momentum = momentum)

for epoch in range(num_epochs):

    # The image tensors are re-read on every epoch; with the shuffle flags set,
    # this presumably draws a different random subset each time (TODO confirm
    # against FileReader.get_picture_tensors).
    (train_images_cat, val_images_cat, test_images_cat,
     train_labels_cat, val_labels_cat, test_labels_cat,
     n_classes) = get_picture_tensors(root_directory=cat_directory,
                                      n_classes=n_classes,
                                      **picture_loader_kwargs)

    (train_images_background, val_images_background, test_images_background,
     train_labels_background, val_labels_background, test_labels_background,
     n_classes) = get_picture_tensors(root_directory=background_directory,
                                      n_classes=n_classes,
                                      **picture_loader_kwargs)

    train_dataset = dataset_cat_no_cat(train_images_cat, train_images_background)
    train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)

    total_loss = 0.0

    for imgs, labels in train_dataloader:
        optimizer.zero_grad()

        # Flatten the model output to 1-D before it is compared with the labels.
        output = torch.flatten(model(imgs))
        loss = criterion(output, labels.float())

        loss.backward()
        optimizer.step()

        # .item() yields a plain Python float and works on any device, whereas
        # .detach().numpy() raises for tensors that are not on the CPU.
        total_loss += loss.item()

    print(f"End of epoch {epoch}")
    print("Total loss in epoch: ", total_loss)

    loss_at_each_epoch.append(total_loss)

    _plot_loss_history(loss_at_each_epoch)


In [None]:
# Collect (and print) the data tensor of every parameter that is still
# trainable, then show the last layer of the classifier head.
all_param = []

for param_name, tensor in model.named_parameters():
    if not tensor.requires_grad:
        # Frozen parameters are neither printed nor collected.
        continue
    print(param_name)
    print(tensor.data)
    all_param.append(tensor.data)

print(model.classifier[-1])



In [None]:
# Scatter-plot the individual values of the first trainable parameter that the
# previous cell stored in `all_param`.
first_param = all_param[0]
print(first_param.shape)

# Convert to a NumPy array explicitly: handing a torch tensor straight to
# matplotlib relies on implicit conversion and fails for tensors that are not
# on the CPU.
y = torch.flatten(first_param).detach().cpu().numpy()

fig, ax = plt.subplots(1, 1, figsize=(5, 2.5))

ax.set_xlabel('Flattened parameter index (-)')
ax.set_ylabel('Parameter value (-)')

ax.plot(np.arange(0, len(y)), y, '.')

fig.tight_layout()

plt.show()




In [None]:

# Persist only the state of the last classifier layer (judging by the file
# name, the attention layer) rather than the whole model.
attention_layer_state = model.classifier[-1].state_dict()
torch.save(attention_layer_state, 'mobilenetv2_attentionlayer_augmented.pth')

In [None]:
from matplotlib.ticker import MultipleLocator

# Base name for the exported figure files; the checkpoint cell further down
# reuses it as well.
filename = f"{model.cnn_backbone}_attention"

# Largest recorded epoch loss, used to give the y-axis 5 % headroom.
maxtot = max(loss_at_each_epoch)

fig, ax = plt.subplots(1, 1, figsize=(5, 2.5))

# Axis labels and range in a single call.
ax.set(xlabel='Epoch (-)', ylabel='Loss (-)', ylim=(0, 1.05*maxtot))
# One minor tick per epoch.
ax.xaxis.set_minor_locator(MultipleLocator(1))

ax.plot(loss_at_each_epoch, '.-')
ax.grid(True)

fig.tight_layout()

# Export a raster and a vector version of the same figure.
fig.savefig(f"{filename}.png", dpi = 300)
fig.savefig(f"{filename}.svg", dpi = 300)

plt.show()








In [None]:


# Store the model parameters under the same base name as the loss figure
# (`filename` comes from the loss-figure cell). A fixed name,
# 'mobilenetv2_attention.pth', was used previously.
checkpoint_path = f"{filename}.pth"
model.save_parameters_to_file(checkpoint_path)


