In [1]:
import torch, torchvision
import os
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch_train import TorchTrain
from torchmetrics import Accuracy, Precision, Recall

In [2]:
to_tensor = transforms.ToTensor()
train_data = datasets.FashionMNIST(
    root="data", train=True, download=True, transform=to_tensor
)
test_data = datasets.FashionMNIST(
    root="data", train=False, download=True, transform=to_tensor
)

BATCH_SIZE = 32
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)


class Model0(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(784, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.linear(x)
        return x


model0 = Model0()

loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model0.parameters(), lr=0.001)

acc = Accuracy(task="multiclass", num_classes=10)
precision = Precision(num_classes=10, task="multiclass")
recall = Recall(num_classes=10, task="multiclass")
metrics = {"accuracy": acc, "precision": precision, "recall": recall}

tt = TorchTrain(model0, optimizer, loss, metrics=metrics)


In [3]:
history = tt.fit(train_loader, test_loader, epochs=2, train_steps_per_epoch=300, validation_steps_per_epoch=200)



In [4]:
tt.evaluate(test_loader, recall)



0.8008186900958466

In [4]:
import pandas as pd

pd.DataFrame(history)

Unnamed: 0,train_loss,val_loss,train_accuracy,val_accuracy,train_precision,val_precision,train_recall,val_recall
0,0.594331,0.505593,0.80235,0.829219,0.80235,0.829219,0.80235,0.829219
1,0.460607,0.479129,0.84325,0.834375,0.84325,0.834375,0.84325,0.834375


In [5]:
tt.train_loss, tt.train_metrics, tt.test_loss, tt.test_metrics

(0, {}, 0, {})

In [6]:
tt.train_loss_all, tt.test_loss_all

([0.5943314106782277, 0.4606070199012756],
 [0.5055934737622738, 0.47912943303585054])

In [7]:
tt.train_metrics_all

[{'Accuracy': 0.80235, 'Precision': 0.80235, 'Recall': 0.80235},
 {'Accuracy': 0.84325, 'Precision': 0.84325, 'Recall': 0.84325}]

In [8]:
tt.test_metrics_all

[{'Accuracy': 0.82921875, 'Precision': 0.82921875, 'Recall': 0.82921875},
 {'Accuracy': 0.834375, 'Precision': 0.834375, 'Recall': 0.834375}]

In [9]:
tt.train_metrics_all[0]

{'Accuracy': 0.80235, 'Precision': 0.80235, 'Recall': 0.80235}