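# Model training

Trains a linear baseline (`BaseModel`) and a convolutional model (`ConvModel`) on the extracted feature tables and inspects their predictions on a held-out 20% split.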

In [1]:
import torch
from torch import nn

In [2]:
# Device-agnostic setup: use the GPU when PyTorch can see one, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
# Load data: read the input-feature table and convert it straight to a float32 tensor
import pandas as pd
X = torch.tensor(
    pd.read_csv("./../data/features/in_features.csv").to_numpy(),
    dtype=torch.float32,
)
X.shape

torch.Size([84, 9897])

In [4]:
# The conversion above should yield float32 features; fail loudly instead of relying on eyeballing
assert X.dtype == torch.float32, f"expected float32 features, got {X.dtype}"
X.dtype

torch.float32

In [5]:
# Load the targets as a float32 tensor. X and y come from separate CSVs and are
# paired purely by row order, so check that they line up before anything uses them.
y = torch.tensor(
    pd.read_csv("./../data/features/out_features.csv").to_numpy(),
    dtype=torch.float32,
)
assert y.shape[0] == X.shape[0], f"X has {X.shape[0]} rows but y has {y.shape[0]}"
# NaN targets would silently turn every loss value into NaN during training
assert not torch.isnan(y).any(), "out_features.csv contains missing values"
y.shape

torch.Size([84, 13])

In [6]:
# Column counts of the feature and target tables; these are passed to the model constructors below
in_features = X.size(1)
out_features = y.size(1)
in_features, out_features

(9897, 13)

## Linear model

In [7]:
from modules.model_builder import BaseModel

# Seed torch before building the model so its random weight initialisation is
# reproducible on Restart & Run All (the split's random_state below only fixes the split).
torch.manual_seed(42)  # same value as the RANDOM_SEED used for the split
# NOTE(review): 10 is presumably the hidden-layer width — confirm in modules.model_builder.
model_0 = BaseModel(in_features, 10, out_features).to(device)

In [8]:
from sklearn.model_selection import train_test_split

RANDOM_SEED = 42
# Hold out 20% of the samples for evaluation; a fixed random_state makes the split reproducible
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    random_state=RANDOM_SEED,
    test_size=0.2,
)

In [9]:
# Shapes of the four splits, in train_test_split's output order
tuple(part.shape for part in (X_train, X_test, y_train, y_test))

(torch.Size([67, 9897]),
 torch.Size([17, 9897]),
 torch.Size([67, 13]),
 torch.Size([17, 13]))

In [10]:
from modules.engine import train

# NOTE(review): the two positional numbers passed to `train` look like the epoch
# count and the learning rate — confirm against modules.engine.train's signature.
NUM_EPOCHS = 100
LEARNING_RATE = 0.1

# NOTE(review): y holds 13 float columns per sample, which CrossEntropyLoss treats
# as per-class probability targets (e.g. one-hot rows). If a sample can carry
# several labels at once, BCEWithLogitsLoss would be the appropriate criterion.
loss_fn = nn.CrossEntropyLoss()
# Passed as a class rather than an instance; presumably `train` builds it from the model's parameters.
optimizer_class = torch.optim.SGD

train(model_0, X_train, X_test, y_train, y_test,
      loss_fn, optimizer_class, NUM_EPOCHS, LEARNING_RATE, device)

In [11]:
# Score the held-out rows: eval mode plus inference_mode (no autograd bookkeeping),
# then take the highest-scoring column of each row as the predicted class index.
model_0.eval()
with torch.inference_mode():
    y_logits = model_0(X_test.to(device))
    y_preds = torch.argmax(y_logits, dim=1).cpu()
y_preds

tensor([0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0])

## Convolutional model

In [12]:
# Add a singleton channel axis: (samples, features) -> (samples, 1, features).
# NOTE(review): presumably ConvModel applies 1-D convolutions expecting
# (batch, channels, length) input — confirm in modules.model_builder.
# The ndim guard keeps this cell idempotent: re-running it no longer stacks extra
# axes. It still rebinds X_train/X_test, so the linear-model cells above will not
# accept these 3-D tensors if they are re-run afterwards.
if X_train.ndim == 2:
    X_train = X_train.unsqueeze(1)
if X_test.ndim == 2:
    X_test = X_test.unsqueeze(1)

In [13]:
# Everything after the batch axis: channel count and sequence length of the conv input
_, in_channels, in_length = X_train.shape
in_channels, in_length

(1, 9897)

In [14]:
from modules.model_builder import ConvModel

# Re-seed so ConvModel's weight initialisation does not depend on how much of the
# RNG stream earlier cells (model_0's construction and training) consumed.
torch.manual_seed(RANDOM_SEED)
# NOTE(review): ConvModel prints a number (2475) when constructed, which looks like
# leftover debug output in modules.model_builder — consider removing it there.
model_1 = ConvModel(in_channels, in_length, 10, out_features).to(device)

2475


In [15]:
# Same loss, optimizer class and positional hyperparameters (100, 0.1) as the linear model
train(
    model_1,
    X_train, X_test,
    y_train, y_test,
    loss_fn,
    optimizer_class,
    100,
    0.1,
    device,
)

In [16]:
from modules.model_builder import out_shape_calc

# Sanity-check the sequence length after a kernel_size=3, stride=1, padding=1 layer
# (presumably out_shape_calc applies the usual conv output-length formula).
# Use the measured in_length rather than a hard-coded 9897 so the check keeps
# tracking the feature table if its width changes.
out_shape_calc(in_length, kernel_size=3, stride=1, padding=1)

9897