In [1]:
import torch
import torch.nn as nn
from src.utils import plot_images
from src.preprocess import preprocess
from src.train import trainNN, evaluate_model
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
import pandas as pd

In [6]:
# Stratified 80/10/10 split: hold out 20% of the rows, then halve that
# holdout into validation and test. Stratifying on `label` keeps the class
# mix nearly identical across all three splits (proportions shown below).
df = pd.read_csv('data/xyz_dataset.csv')
df_train, df_holdout = train_test_split(
    df, test_size=0.2, random_state=42, stratify=df['label']
)
df_valid, df_test = train_test_split(
    df_holdout, test_size=0.5, random_state=42, stratify=df_holdout['label']
)

# Class proportions per split, in the order (train, test, valid).
tuple(
    split["label"].value_counts(normalize=True)
    for split in (df_train, df_test, df_valid)
)

(label
 1.0    0.597455
 2.0    0.334319
 0.0    0.068226
 Name: proportion, dtype: float64,
 label
 1.0    0.597360
 2.0    0.334433
 0.0    0.068207
 Name: proportion, dtype: float64,
 label
 1.0    0.597688
 2.0    0.334067
 0.0    0.068244
 Name: proportion, dtype: float64)

In [7]:
# Seed torch's global RNG before anything stochastic happens. Without it, the
# weight initialisation below (and any shuffling that draws from the global
# generator) differs on every Restart & Run All, so the logged accuracies
# cannot be reproduced. GPU kernels may still add some nondeterminism.
SEED = 42  # same value as the random_state used for the data split above
torch.manual_seed(SEED)

BATCH_SIZE = 256
train_loader = preprocess(df_train, batch_size=BATCH_SIZE)
valid_loader = preprocess(df_valid, batch_size=BATCH_SIZE)
test_loader = preprocess(df_test, batch_size=BATCH_SIZE)  # held out; not passed to trainNN below

# Two conv/pool blocks followed by a small MLP head with 3 outputs,
# one per label value (0, 1, 2) seen in the split above.
# NOTE(review): the flattened size 32 * 5 * 5 assumes single-channel 28x28
# inputs (28 -conv-> 26 -pool-> 13 -conv-> 11 -pool-> 5); confirm against
# what preprocess() produces.
model = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(32, 32, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Flatten(),
    nn.Linear(32 * 5 * 5, 128),
    nn.ReLU(),
    nn.Linear(128, 3),
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # move parameters before building the optimizer over them
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10

# The evaluation loader handed to trainNN is valid_loader, so the metrics it
# returns (and presumably the "Test Loss/Accuracy" lines it logs) describe
# the validation split. The real test split is left untouched here.
model, train_loss, train_acc, valid_loss, valid_acc = trainNN(
    model, train_loader, valid_loader, criterion, optimizer, num_epochs, device,
    log_train=True, log_test=True,
)
# Keep the original names available for any later cells that reference them.
test_loss, test_acc = valid_loss, valid_acc

Epoch [1/10], Loss: 0.6236, Accuracy: 92.43%
Test Loss: 0.0611, Test Accuracy: 99.01%
Epoch [2/10], Loss: 0.0166, Accuracy: 99.48%
Test Loss: 0.0188, Test Accuracy: 99.28%
Epoch [3/10], Loss: 0.0081, Accuracy: 99.79%
Test Loss: 0.0348, Test Accuracy: 99.67%
Epoch [4/10], Loss: 0.0048, Accuracy: 99.88%
Test Loss: 0.0184, Test Accuracy: 99.56%
Epoch [5/10], Loss: 0.0023, Accuracy: 99.97%
Test Loss: 0.0147, Test Accuracy: 99.56%
Epoch [6/10], Loss: 0.0015, Accuracy: 99.99%
Test Loss: 0.0291, Test Accuracy: 99.61%
Epoch [7/10], Loss: 0.0027, Accuracy: 99.92%
Test Loss: 0.0249, Test Accuracy: 99.67%
Epoch [8/10], Loss: 0.0033, Accuracy: 99.90%
Test Loss: 0.0227, Test Accuracy: 99.50%
Epoch [9/10], Loss: 0.0010, Accuracy: 99.97%
Test Loss: 0.0126, Test Accuracy: 99.67%
Epoch [10/10], Loss: 0.0004, Accuracy: 100.00%
Test Loss: 0.0155, Test Accuracy: 99.67%
