In [1]:
import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

In [2]:
# Two NumPy .npz archives, each read for a "data" array and a "target" array
# (presumably samples and their labels — confirm against the data producer).
data_path = "../data/data.npz"
data_ai_path = "../data/data_ai.npz"
data, data_ai = (np.load(path) for path in (data_path, data_ai_path))

# Stack both archives along axis 0: rows from data.npz come first, then rows
# from data_ai.npz, in the same order for X and target.
# NOTE(review): the NpzFile handles stay open for the kernel's lifetime;
# consider `with np.load(...)` if `data` / `data_ai` are not used later.
X, target = (
    np.concatenate([data[key], data_ai[key]], axis=0)
    for key in ("data", "target")
)

In [3]:
# Hold out 30% of the samples for evaluation. A fixed seed makes the split
# (and therefore every downstream metric) reproducible across kernel restarts.
SPLIT_SEED = 42
X_train, X_test, y_train, y_test = train_test_split(
    X, target, test_size=0.3, random_state=SPLIT_SEED
)

# Reshape every sample into a single-channel 128x128 image: (N, 1, 128, 128).
# NOTE(review): assumes each sample holds exactly 128*128 values — confirm.
# Inputs are cast to float32 so they match the dtype of the model's conv/linear
# weights (PyTorch's default); float64 or integer arrays would otherwise fail
# with a dtype mismatch in the forward pass. Labels keep their original dtype.
X_train_tensor = torch.tensor(X_train, dtype=torch.float32).view(-1, 1, 128, 128)
y_train_tensor = torch.tensor(y_train)

X_test_tensor = torch.tensor(X_test, dtype=torch.float32).view(-1, 1, 128, 128)
y_test_tensor = torch.tensor(y_test)

# Seed the loader's shuffling RNG as well, so the batch order is reproducible.
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [4]:
class CNN(nn.Module):
    """Small two-stage convolutional classifier for single-channel square images.

    Architecture:
        [Conv5x5(1->8) -> ReLU -> MaxPool2] -> [Conv5x5(8->16) -> ReLU -> MaxPool2]
        -> flatten -> Linear(->64) -> ReLU -> Dropout -> Linear(->number_of_emotions)

    Args:
        number_of_emotions: Number of output classes (length of the score vector).
        input_size: Side length, in pixels, of the square input images. The
            default of 128 matches the ``view(-1, 1, 128, 128)`` reshape used
            when building the datasets.
        dropout: Dropout probability applied after ``fc1``.

    Raises:
        ValueError: if ``input_size`` is too small to survive the two 2x2 pools.

    Input:  float tensor of shape (batch, 1, input_size, input_size).
    Output: raw (unnormalized) scores of shape (batch, number_of_emotions).
    """

    def __init__(self, number_of_emotions=4, input_size=128, dropout=0.35):
        super().__init__()
        if input_size < 4:
            raise ValueError(f"input_size must be >= 4, got {input_size}")

        self.conv_layers = nn.Sequential(
            # padding=2 with a 5x5 kernel keeps H and W unchanged;
            # each MaxPool2d(2, 2) then halves them (rounding down).
            nn.Conv2d(1, 8, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(8, 16, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        # Two floor-halving pools shrink each side to input_size // 4
        # (128 -> 64 -> 32 for the default), with 16 channels out of the convs.
        pooled_side = input_size // 4
        self.fc1 = nn.Linear(16 * pooled_side * pooled_side, 64)
        self.dropout = nn.Dropout(p=dropout)
        self.fc2 = nn.Linear(64, number_of_emotions)

    def forward(self, x):
        """Map a (batch, 1, H, W) image batch to (batch, number_of_emotions) scores."""
        x = self.conv_layers(x)
        x = torch.flatten(x, 1)  # (batch, 16, h, w) -> (batch, 16*h*w)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
    
# Build the network with its default configuration and report its size:
# the total number of scalar values across all of its parameter tensors.
model = CNN()
n_parameters = sum(map(torch.Tensor.numel, model.parameters()))
n_parameters

1052324