In [None]:
import os

import numpy as np
import torch
import wandb
from skimage import transform
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import resnet34
from torchvision.utils import make_grid

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Start a Weights & Biases run in the 'net_web_images' project; per-epoch
# metrics are sent with wandb.log and best values go into the run summary.
wandb.init(project='net_web_images')

In [None]:
# Prefer the first CUDA GPU; fall back to the CPU when no GPU is visible.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

In [None]:
# Preprocessing pipeline applied to every image loaded by ImageFolder.
# Crop and resize while the sample is still a PIL image, so resizing uses PIL's
# antialiased resampling (resizing a tensor is not antialiased by default in
# older torchvision, and that warning is hidden by the global warnings filter).
# Only then convert to a float tensor and normalise with the ImageNet channel
# mean/std expected by the pretrained backbone.
# NOTE(review): per torchvision docs, CenterCrop zero-pads images smaller than
# 300 px on a side — confirm the source images are large enough.
composed = transforms.Compose([
    transforms.CenterCrop(300),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [None]:
# Image dataset laid out the way ImageFolder expects (one sub-folder per class).
# The location can be overridden with the DATA_ROOT environment variable;
# without it the original relative path 'root' is used.
DATA_ROOT = os.environ.get('DATA_ROOT', 'root')
dataset = ImageFolder(root=DATA_ROOT, transform=composed)

In [None]:
# Class index of every sample, in dataset order; used to stratify the split below.
targets = dataset.targets

In [None]:
# Stratified 80/20 split over sample indices (class proportions preserved).
# A fixed seed makes the split reproducible. Seeding torch here also fixes the
# downstream torch randomness, such as the new classification head's initial
# weights and the order in which SubsetRandomSampler draws samples.
SEED = 42
torch.manual_seed(SEED)
train_indexes, test_indexes = train_test_split(
    np.arange(len(targets)), test_size=0.2, shuffle=True, stratify=targets, random_state=SEED)

In [None]:
# One sampler per split: each yields only its own subset of indices, in random order.
train_sampler, test_sampler = (SubsetRandomSampler(idx) for idx in (train_indexes, test_indexes))

In [None]:
# Mini-batch size shared by the train and test DataLoaders.
batch_size = 32

In [None]:
# Both loaders read from the same ImageFolder; the sampler decides which
# indices each one sees.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=batch_size, sampler=split_sampler)
    for split_sampler in (train_sampler, test_sampler)
)

In [None]:
# ResNet-34 backbone with pretrained weights; its classification head is
# replaced in the next cell.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=...` (the deprecation warning is hidden by the global warnings
# filter) — confirm against the installed torchvision version.
net = resnet34(pretrained=True)

In [None]:
# Replace the pretrained classification head with a fresh linear layer that has
# one output per dataset class. The size comes from the classes ImageFolder
# actually found (previously hard-coded as 17), so the head cannot drift out of
# sync with the data: a mismatch either breaks CrossEntropyLoss or wastes logits.
net.fc = nn.Linear(net.fc.in_features, len(dataset.classes))

In [None]:
# Move the model's parameters and buffers to the device chosen above.
net = net.to(device)

In [None]:
# Multi-class classification loss computed on the network's raw logits.
criterion = nn.CrossEntropyLoss()

# NOTE(review): 0.01 is high for Adam when fine-tuning pretrained weights
# (Adam's own default is 1e-3); consider lowering it.
learning_rate = 0.01
optimizer = Adam(params=net.parameters(), lr=learning_rate)

In [None]:
def _split_metrics(y_true, y_pred):
    """Return (accuracy, macro precision, macro recall, macro F1) for one split.

    zero_division=0 yields the same value sklearn uses by default for classes
    with no predicted/true samples, without emitting UndefinedMetricWarning.
    """
    return (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, average='macro', zero_division=0),
        recall_score(y_true, y_pred, average='macro', zero_division=0),
        f1_score(y_true, y_pred, average='macro', zero_division=0),
    )


def log_metrics(y_true_train, y_pred_train, y_true_test, y_pred_test, epoch):
    """Compute train/test classification metrics, log them to W&B, and return them.

    Args:
        y_true_train, y_pred_train: true and predicted class indices on the train split.
        y_true_test, y_pred_test: true and predicted class indices on the test split.
        epoch: epoch number recorded alongside the metrics in wandb.log.

    Returns:
        (accuracy_train, accuracy_test, precision_train, precision_test,
         recall_train, recall_test, f_score_train, f_score_test); precision,
        recall and F1 are macro-averaged over classes.
    """
    accuracy_train, precision_train, recall_train, f_score_train = _split_metrics(y_true_train, y_pred_train)
    accuracy_test, precision_test, recall_test, f_score_test = _split_metrics(y_true_test, y_pred_test)
    wandb.log({'train accuracy': accuracy_train, 'train precision': precision_train, 'train recall': recall_train, 'train f_score': f_score_train,
               'test accuracy': accuracy_test, 'test precision': precision_test, 'test recall': recall_test, 'test f_score': f_score_test, 'epoch': epoch})
    return accuracy_train, accuracy_test, precision_train, precision_test, recall_train, recall_test, f_score_train, f_score_test

In [None]:
# Number of full passes over the training split.
epochs_n = 10

In [None]:
def _predict(loader):
    """Return (y_true, y_pred) lists of class indices for every sample in `loader`.

    Uses the notebook-level `net` and `device`; the caller is expected to have
    put `net` in eval mode. Gradients are disabled because only predictions
    are needed.
    """
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in loader:
            logits = net(images.to(device))
            y_true.extend(labels.tolist())
            # Index of the highest logit per row (same as torch.max(logits, 1)[1]).
            y_pred.extend(logits.argmax(dim=1).tolist())
    return y_true, y_pred


log_every = 100  # print the mean training loss every this many batches

best_accuracy_train, best_accuracy_test, best_precision_train, best_precision_test, best_recall_train, best_recall_test, best_f_score_train, best_f_score_test = [0] * 8
for epoch in range(epochs_n):
    # --- training pass ---
    net.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % log_every == 0:
            print(f'Epoch: {epoch + 1}, {i + 1}/{len(train_loader)}, loss: {running_loss / log_every}')
            running_loss = 0.0

    # --- evaluation on both splits (eval mode: batch-norm uses running stats) ---
    net.eval()
    y_true_train, y_pred_train = _predict(train_loader)
    y_true_test, y_pred_test = _predict(test_loader)

    (accuracy_train, accuracy_test, precision_train, precision_test,
     recall_train, recall_test, f_score_train, f_score_test) = log_metrics(
        y_true_train, y_pred_train, y_true_test, y_pred_test, epoch + 1)

    # Track the best value of each metric independently across epochs.
    # NOTE(review): each "best" may come from a different epoch, and the test
    # maxima are selected on the test set itself, so they are optimistic.
    best_accuracy_train = max(best_accuracy_train, accuracy_train)
    best_accuracy_test = max(best_accuracy_test, accuracy_test)
    best_recall_train = max(best_recall_train, recall_train)
    best_recall_test = max(best_recall_test, recall_test)
    best_precision_train = max(best_precision_train, precision_train)
    best_precision_test = max(best_precision_test, precision_test)
    best_f_score_train = max(best_f_score_train, f_score_train)
    best_f_score_test = max(best_f_score_test, f_score_test)

print('Finished')

In [None]:
# Store the best value of each metric seen during training in the W&B run
# summary (each metric's best may come from a different epoch).
for summary_key, best_value in (
    ("accuracy train", best_accuracy_train),
    ("accuracy test", best_accuracy_test),
    ("recall train", best_recall_train),
    ("recall test", best_recall_test),
    ("precision train", best_precision_train),
    ("precision test", best_precision_test),
    ("f-score train", best_f_score_train),
    ("f-score test", best_f_score_test),
):
    wandb.run.summary[summary_key] = best_value