This script reproduces what the main.py function does but divided into its parts so we can visualize the whole process

In [4]:
import os
import torch
import torch.nn as nn
import numpy as np
import argparse
import pickle
from ucimlrepo import fetch_ucirepo 

In [5]:
from src.dataset import Data_handling
from src.weakener import Weakener
from src.model import MLP
from utils.datasets_generation import generate_dataset
import utils.losses as losses
from utils.train_test_loop import train_and_evaluate

In [6]:
reps = 10
dataset_base_path = 'Datasets/weak_datasets'
dataset = 'mnist'
corruption = 'Noisy_Patrini_MNIST'
corr_p = 0.7
corr_n = None
loss_type = 'Forward'

for i in range(reps):
        generate_dataset(dataset=dataset,corruption=corruption,corr_p=corr_p,repetitions=i)


In [11]:


for i in range(reps):
    base_dir = "Datasets/weak_datasets"
    if corr_n is not None:
        folder_path = os.path.join(base_dir, f'{dataset}_{corruption}_p_+{corr_p}p_-{corr_n}')
    else:
        folder_path = os.path.join(base_dir, f'{dataset}_{corruption}_p{corr_p}')
    f = open(folder_path + f'/Dataset_{i}.pkl','rb')
    Data,Weak = pickle.load(f)
    f.close()

    
    
    if loss_type == 'Backward':
        loss_fn = losses.FwdBwdLoss(Weak.Y,np.eye(Weak.c))
    elif loss_type == 'Forward':
        loss_fn = losses.FwdBwdLoss(np.eye(Weak.d),Weak.M)
    elif loss_type == 'EM':
        loss_fn = losses.EMLoss(Weak.M)
    elif loss_type == 'LBL':
        loss_fn = losses.LBLoss()
    elif loss_type == 'Backward_opt':
        loss_fn = losses.FwdBwdLoss(Weak.Y_opt,np.eye(Weak.c))
    elif loss_type == 'Backward_conv':
        loss_fn = losses.FwdBwdLoss(Weak.Y_conv,np.eye(Weak.c))
    elif loss_type == 'Backward_opt_conv':
        loss_fn = losses.FwdBwdLoss(Weak.Y_opt_conv,np.eye(Weak.c))
    elif loss_type == 'OSL':
        loss_fn = losses.OSLCELoss()
    
    if loss_type == 'OSL':
        Data.include_weak(Weak.w)
    else:
        Data.include_weak(Weak.z)

    
    trainloader,testloader = Data.get_dataloader(weak_labels = 'weak')

    print(Data.num_features)


    #trainloader,testloader = Data.get_dataloader()
    mlp = MLP(Data.num_features, [Data.num_features], Weak.c, dropout_p=0.3, bn = True, activation='tanh')
    optim = torch.optim.Adam(mlp.parameters(), lr=1e-3)
    mlp,results = train_and_evaluate(mlp,trainloader,testloader,optimizer=optim,loss_fn=loss_fn,corr_p=corr_p,num_epochs=100,sound=10,rep=i)







# Bwd = FwdBwdLoss(pinv(M),I_c)
# Fwd = FwdBwdLoss(I_d,M)

784
Epoch 10/100: Train Loss: 0.2252, Train Acc: 0.9746, Test Acc: 0.9689, Train Detached Loss: 0.0017, Test Detached Loss: 0.0036, Learning Rate: 0.001000
Epoch 20/100: Train Loss: 0.1899, Train Acc: 0.9860, Test Acc: 0.9716, Train Detached Loss: 0.0012, Test Detached Loss: 0.0042, Learning Rate: 0.001000


KeyboardInterrupt: 

In [None]:

res_dir = f"Results/{dataset}_{corruption}"
os.makedirs(folder_path, exist_ok=True)
if corr_n is not None:
    file_name = f'{loss_type}_p_+{corr_p}p_-{corr_n}.csv'
else:
    file_name = f'{loss_type}_p_+{corr_p}p_-{corr_n}.csv'
file_path = os.path.join(res_dir, file_name)
results.to_csv(file_path, index=False)

print(f'DataFrame saved as CSV at: {file_path}')


DataFrame saved as CSV at: Results\image_pll_p0.5\Forward.csv


In [None]:
results

Unnamed: 0,epoch,train_loss,train_acc,test_acc,train_detached_loss,test_detached_loss,optimizer,loss_fn,repetition,initial_lr,actual_lr,corr_p
0,1,3.469290,0.196429,0.142857,0.066955,0.090892,Adam,FwdBwdLoss,9,0.001,0.001,0.5
1,2,3.441603,0.285714,0.214286,0.065961,0.089786,Adam,FwdBwdLoss,9,0.001,0.001,0.5
2,3,3.413559,0.285714,0.261905,0.065322,0.088971,Adam,FwdBwdLoss,9,0.001,0.001,0.5
3,4,3.388708,0.285714,0.285714,0.064642,0.088314,Adam,FwdBwdLoss,9,0.001,0.001,0.5
4,5,3.371164,0.285714,0.285714,0.063951,0.087665,Adam,FwdBwdLoss,9,0.001,0.001,0.5
...,...,...,...,...,...,...,...,...,...,...,...,...
95,96,2.707886,0.583333,0.500000,0.035038,0.062313,Adam,FwdBwdLoss,9,0.001,0.001,0.5
96,97,2.713032,0.607143,0.500000,0.034726,0.061899,Adam,FwdBwdLoss,9,0.001,0.001,0.5
97,98,2.714207,0.595238,0.547619,0.034487,0.061459,Adam,FwdBwdLoss,9,0.001,0.001,0.5
98,99,2.714451,0.601190,0.547619,0.034288,0.061044,Adam,FwdBwdLoss,9,0.001,0.001,0.5
