In [None]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, TensorDataset

In [None]:
from utils import *
from classFNN import FNN

---

#### Loading Data

In [None]:
# Input files for the Indian Pines dataset: the data array and the
# (presumably ground-truth) "gt" label array.
data_path, gt_path = (
    '../Indian Pines dataset/indianpinearray.npy',
    '../Indian Pines dataset/IPgt.npy',
)

# load_data (provided by utils) returns the split as (X_train, y_train, X_test, y_test).
X_train, y_train, X_test, y_test = load_data(data_path, gt_path)

---

#### Training

In [None]:
# Experiment name. It selects the architecture list
# (../architectures/models_<mode>.txt) and the ../runs / ../results output
# folders. The training loop also matches it by substring to pick the loss:
# 'Lasso', 'Ridge' or 'Elastic_net' add a regulariser; anything else uses plain
# cross-entropy.
mode = 'default_FNN'

# Hyperparameters shared by every architecture trained in this run.
args = {
    'batch': 256,       # mini-batch size for the train/val/test DataLoaders
    'epochs': 100,      # training epochs per architecture
    'lr': 1e-3,         # Adam learning rate
    'l1_lambda': 0,     # L1 weight (used by the 'Lasso' / 'Elastic_net' modes)
    'l2_lambda': 0,     # L2 weight (used by the 'Ridge' / 'Elastic_net' modes)
    'dropout': 0,       # forwarded to FNN(...)
    'bn': False,        # forwarded to FNN(...); presumably toggles batch norm
    'k_folds': 10,      # number of KFold splits rotated over during training
    }

# Fix the global RNGs so weight initialisation and DataLoader shuffling are
# reproducible under Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Shuffled k-fold splitter over the training set. Because random_state is a
# fixed int, every call to kf.split(...) yields the same fold assignment.
kf = KFold(
    n_splits=args['k_folds'],
    random_state=42,
    shuffle=True,
)

In [None]:
# Held-out test split as float features and integer (long) labels. The sample
# order is preserved (shuffle=False).
test_dataset = TensorDataset(torch.FloatTensor(X_test), torch.LongTensor(y_test))
test_loader = DataLoader(test_dataset, batch_size=args['batch'], shuffle=False)

In [None]:
# Architectures for this mode. Each entry is later passed to FNN(params, ...)
# and '-'-joined to name its run directory.
architectures_file = f'../architectures/models_{mode}.txt'
models_to_train = load_models_architectures(architectures_file)
print(f'Models to train: {len(models_to_train)}')

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')

In [None]:
def compute_training_loss(model, criterion, outputs, labels, mode, args):
    """Return the loss to back-propagate for one batch.

    The regulariser is chosen by substring match on ``mode``, checked in order:
    'Lasso' (L1), 'Ridge' (L2), 'Elastic_net' (L1 + L2). Any other mode uses
    plain ``criterion``.
    """
    if 'Lasso' in mode:
        return lasso_loss(model, criterion, outputs, labels, args['l1_lambda'])
    if 'Ridge' in mode:
        return ridge_loss(model, criterion, outputs, labels, args['l2_lambda'])
    if 'Elastic_net' in mode:
        return elastic_net_loss(model, criterion, outputs, labels,
                                args['l1_lambda'], args['l2_lambda'])
    return criterion(outputs, labels)


def make_loader(X, y, idx, batch_size, shuffle):
    """Wrap rows ``idx`` of (X, y) in a DataLoader (float features, long labels)."""
    dataset = TensorDataset(torch.FloatTensor(X[idx]), torch.LongTensor(y[idx]))
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def train_one_epoch(model, loader, criterion, optimizer, device, mode, args):
    """Run one optimisation pass over ``loader``.

    Returns (mean per-batch loss, accuracy). The loss includes any regulariser
    selected by ``mode``.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = compute_training_loss(model, criterion, outputs, labels, mode, args)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        n_correct += (outputs.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)
    return loss_sum / len(loader), n_correct / n_seen


def validate_one_epoch(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without gradient tracking.

    Returns (mean per-batch loss, accuracy). The loss uses ``criterion`` only
    (no regulariser), so in regularised modes it is not directly comparable to
    the training loss. Leaves the model in eval mode.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, labels).item()
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)
    return loss_sum / len(loader), n_correct / n_seen


TRAIN_LOG_COLUMNS = ['epochs', 'train_loss', 'val_loss', 'train_acc', 'val_acc']
RESULT_COLUMNS = ['Model', 'Accuracy', 'F1_score']

# kf has an integer random_state, so split() yields the same folds on every
# call. Compute them once instead of re-splitting X_train on every epoch.
fold_splits = list(kf.split(X_train))
n_folds = len(fold_splits)

result_rows = []
for params in models_to_train:
    model = FNN(params, args['dropout'], args['bn'])
    model.to(device)

    model_name = '-'.join(map(str, params))
    path = f'../runs/{mode}/{model_name}'

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args['lr'])

    # Save the run configuration together with the model/optimizer objects.
    # Pass a copy so the shared `args` dict stays hyperparameters-only and does
    # not keep model objects alive in notebook state.
    save_params(path, model_name, {**args, 'optimizer': optimizer, 'model': model})

    epoch_rows = []
    for epoch in range(args['epochs']):
        # Rotate through every configured fold: epoch e trains on the training
        # part of fold (e mod k) and validates on its held-out part.
        # NOTE(review): the fold changes every epoch, so from the second epoch
        # on the validation samples were in an earlier epoch's training split.
        # Later val metrics are therefore not strictly held-out estimates.
        train_idx, val_idx = fold_splits[epoch % n_folds]
        train_loader = make_loader(X_train, y_train, train_idx,
                                   args['batch'], shuffle=True)
        val_loader = make_loader(X_train, y_train, val_idx,
                                 args['batch'], shuffle=False)

        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, mode, args)
        val_loss, val_acc = validate_one_epoch(model, val_loader, criterion, device)

        epoch_rows.append([epoch + 1, train_loss, val_loss, train_acc, val_acc])

    log_train = pd.DataFrame(epoch_rows, columns=TRAIN_LOG_COLUMNS)
    save_res(data=log_train, path=path,
             rewrite=True, file_name='training')

    test_accuracy, test_f1 = evaluate_model(model, test_loader)
    print(f'{model_name}: Accuracy: {test_accuracy:.3f}, F1-score: {test_f1:.3f}')

    test_metrics = pd.DataFrame(data=[[test_accuracy, test_f1]],
                                columns=['Accuracy', 'F1_score'])
    save_res(data=test_metrics, path=path)

    # Pickles the whole module. Loading it back with torch.load needs the FNN
    # class to be importable.
    torch.save(model, os.path.join(path, 'model.pth'))

    result_rows.append([model_name, test_accuracy, test_f1])

log_res = pd.DataFrame(result_rows, columns=RESULT_COLUMNS)
save_res(data=log_res, path=f'../results/{mode}', rewrite=False)

print('=========')
print('Training completed')

---

#### Getting Results

In [None]:
# Summary plot of the test metrics of every architecture trained in this mode.
# Read from the same directory the training cell saves its summary to
# (save_res(..., path=f'../results/{mode}')). The old '../results/FNN/{mode}'
# path did not match it.
# NOTE(review): assumes save_res's default file name is 'results' - confirm.
results_dir = f'../results/{mode}'
plot_res(path_csv=f'{results_dir}/results.csv',
         path_png=results_dir)

print('=========')
print('Plot saved')

In [None]:
# Plot the train/val loss curves of every architecture trained in this mode.
# Run-directory names use the same rule as the training loop
# ('-'.join(map(str, params))), so they always match the folders under
# ../runs/{mode}/. The loop variables are not named `model`, which would
# overwrite the trained model object left in the notebook namespace.
for arch in load_models_architectures(f'../architectures/models_{mode}.txt'):
    run_name = '-'.join(map(str, arch))
    plot_train_val_loss(f'../runs/{mode}/{run_name}/training.csv',
                        path_png=f'../results/{mode}/train_val_loss',
                        name_png=run_name,
                        model_name=run_name)

print('=========')
print('Plots saved')