In [None]:
from torchvision import datasets, transforms, models
import torch.nn as nn
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt
from torchinfo import summary

## Mise en place des données d'entraînement et de test

In [None]:
import os

batch_size = 10
in_channel = 3  # number of input image channels
data_dir = ''   # TODO: set to the dataset root, laid out as <data_dir>/<class_name>/<images>

if not os.path.isdir(data_dir):
    raise FileNotFoundError(
        f"data_dir={data_dir!r} is not a directory; point it at the dataset root "
        "(one sub-folder per class)."
    )

# Only sub-directories are classes. This mirrors how torchvision's ImageFolder discovers
# classes, so stray files (e.g. .DS_Store) cannot shift the class list or its indices.
classes = sorted(entry.name for entry in os.scandir(data_dir) if entry.is_dir())
if not classes:
    raise FileNotFoundError(f"No class sub-folders found in {data_dir!r}.")
num_classes = len(classes)
# Number of entries in the first class folder.
num_files = len(os.listdir(os.path.join(data_dir, classes[0])))


In [None]:
# Preprocessing: resize + center-crop to 640x640, then normalise with the ImageNet
# channel statistics used by the pretrained backbone.
trfs = transforms.Compose([
    transforms.Resize(640),
    transforms.CenterCrop(640),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

dataset = datasets.ImageFolder(data_dir, transform=trfs)

# Split a fixed fraction of all images into the training set so the train/test ratio
# does not depend on how many files one particular class folder holds.
train_fraction = 0.8
split_seed = 42
num_train = int(round(train_fraction * len(dataset)))
data_train, data_test = torch.utils.data.random_split(
    dataset,
    [num_train, len(dataset) - num_train],
    # Seeded generator -> same split after every kernel restart.
    generator=torch.Generator().manual_seed(split_seed),
)

train_loader = torch.utils.data.DataLoader(
    data_train,
    batch_size=batch_size,
    shuffle=True,
)
# Evaluation needs no shuffling; a fixed order keeps validation runs comparable.
test_loader = torch.utils.data.DataLoader(
    data_test,
    batch_size=batch_size,
    shuffle=False,
)

from tools.utils import display_dataset

# dataset.classes is the label order ImageFolder actually assigned.
display_dataset(data_train, n=num_classes, classes=dataset.classes)



In [None]:
from tools.models import EfficientNet, Dense, OneCNN, MultiLayerCNN, LeNet

# ImageNet-pretrained EfficientNetV2-S. `weights=` replaces the deprecated `pretrained=True`.
model = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)

# Transfer learning: freeze every pretrained parameter.
model.requires_grad_(False)

# Swap only the last layer of the classifier head. This keeps the head's Dropout and
# reads the feature width from the existing layer instead of hard-coding 1280.
# A freshly created nn.Linear has requires_grad=True, so it is the only trainable part.
model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)

summary(model, (1, in_channel, 640, 640))


### Mise en place des modèles

### Training the model

In [None]:
num_epochs = 10
learning_rate = 0.001
loss_fn = nn.CrossEntropyLoss()
# Give the optimizer only the parameters left trainable by the transfer-learning setup;
# frozen backbone weights never receive gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)

import os
import wandb

# Credentials are read from the environment instead of being hard-coded in the notebook.
# None (variable unset) is the same as not passing the argument to wandb.login.
wandb.login(key=os.environ.get("WANDB_API_KEY"), host=os.environ.get("WANDB_BASE_URL"))
# wandb.config is not callable; hyper-parameters are recorded through init(config=...).
wandb.init(
    project="efficientnet_v2_s",
    config={"epochs": num_epochs, "learning_rate": learning_rate, "batch_size": batch_size},
)
wandb.watch(model)

from tools.utils import train

# NOTE(review): the positional True is forwarded unchanged to tools.utils.train;
# its meaning is defined there (not visible here) — consider passing it by keyword.
history = train(model, train_loader, test_loader, optimizer, loss_fn, True, num_epochs)

In [None]:
def plot_history(history, train_key, val_key, ylabel, ylim=None):
    """Plot the train and validation curves of one metric from a training history.

    Args:
        history: mapping whose `train_key` / `val_key` entries are per-epoch sequences.
        train_key: key of the training curve (also used as its legend label).
        val_key: key of the validation curve (also used as its legend label).
        ylabel: y-axis label and title prefix.
        ylim: optional (bottom, top) limits; either side may be None to keep autoscaling.

    Returns:
        The (fig, ax) pair of the new figure.
    """
    train_values = history[train_key]
    val_values = history[val_key]
    fig, ax = plt.subplots(figsize=(15, 5))
    # x positions come from each curve's own length, so a run that stopped early still plots.
    ax.plot(range(1, len(train_values) + 1), train_values, label=train_key, color='blue')
    ax.plot(range(1, len(val_values) + 1), val_values, label=val_key, color='orange')
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set(xlabel='epoch', ylabel=ylabel, title=f'{ylabel} per epoch')
    ax.legend(loc='upper left')
    return fig, ax


plot_history(history, 'train_acc', 'val_acc', 'accuracy', ylim=(0, 1))
# Cross-entropy loss is unbounded above, so only the lower limit is pinned.
plot_history(history, 'train_loss', 'val_loss', 'loss', ylim=(0, None));

In [None]:
# Save only the learned weights (state_dict) rather than pickling the whole nn.Module:
# the file is then independent of the notebook's class/module layout, and it can be
# loaded with torch.load(..., weights_only=True).
torch.save(model.state_dict(), 'model.pt')

In [None]:
# Keep the loaded checkpoint in a variable so later cells can use it, and map all tensors
# to the CPU so loading also works on a machine without the training device.
# NOTE(review): from PyTorch 2.6 torch.load defaults to weights_only=True, which loads a
# state_dict but not a whole pickled nn.Module — save with model.state_dict().
checkpoint = torch.load('model.pt', map_location='cpu')
checkpoint