In [1]:
import torch
from torchvision.transforms import Compose, Resize, RandomVerticalFlip, RandomHorizontalFlip
from torchsummary import summary

import numpy as np
import pandas as pd
from PIL import Image
import pickle
import os
from sklearn.preprocessing import LabelEncoder

In [2]:
class CASME2Selected(torch.utils.data.Dataset):

    def __init__(self, root, excel, no_frames=10, transform=None):
        super(CASME2Selected, self).__init__()
        self.root = root
        self.excel = excel
        self.no_frames = no_frames
        self.transform = transform
        self.df = pd.read_excel(
                    excel,
                    usecols=["Subject", "Filename", "OnsetFrame", "ApexFrame", "OffsetFrame", "Action Units", "Estimated Emotion"]
                )
        self.df["label"] = self.get_labels()
        self.df["Frame root"] = self.df.apply(lambda x: self.get_paths(x), axis=1)
        self.df["Frames"] = self.df.apply(lambda x: self.get_frames(x["Frame root"]), axis=1)

    def get_paths(self,row):
        sub = str(row["Subject"])
        sub = "sub"+sub if len(sub) == 2 else "sub0"+ sub
        ep = row["Filename"]

        return os.path.join(self.root, sub, ep)   

    def get_frames(self, path):
        frames = list()

        for i in os.scandir(path):
            if i.is_file():
                frames.append(i.path)

        return frames
    
    def get_labels(self):
        self.encoder = LabelEncoder()
        outputs = self.encoder.fit_transform(self.df["Estimated Emotion"])
        return pd.DataFrame(outputs, columns=["label"])
    
    def save_encoder(self, path="./encoder.pkl"):
        with open(path, "wb") as f:
            pickle.dump(self.encoder, f)

    def __len__(self):
        return len(self.df)
    
    def getitem(self,index):
        return self.df.iloc[index]
    
    def __getitem__(self, index):
        item = self.getitem(index)
        frame_paths = item["Frames"][:self.no_frames]
        frames = list()

        for i in frame_paths:
            frame = Image.open(i)
            frame = np.array(frame)
            frame = frame.transpose(2,0,1)
            # frame = frame.reshape(frame.shape[0], frame.shape[1], 3)
            frames.append(frame)

        frames = torch.tensor(np.array(frames), dtype=torch.float32)
        if self.transform:
            frames = self.transform(frames)

        frames = frames.permute(1,0,2,3)

        return frames, torch.tensor(np.eye(7)[item["label"]], dtype=torch.float32)

In [None]:
transform = Compose([
    Resize((200,200)),
    RandomHorizontalFlip(0.3),
    RandomVerticalFlip(0.3)
])

In [30]:
dataset = CASME2Selected(
            root="F:\casme2\CASME2_RAW_selected\CASME2_RAW_selected",
            excel="F:\casme2\CASME2-coding-20140508.xlsx",
            no_frames=20,
            transform=transform
        )

In [31]:
dataset.getitem(1)

Subject                                                              1
Filename                                                       EP03_02
OnsetFrame                                                         131
ApexFrame                                                          139
OffsetFrame                                                        161
Action Units                                                        18
Estimated Emotion                                               others
label                                                                3
Frame root           F:\casme2\CASME2_RAW_selected\CASME2_RAW_selec...
Frames               [F:\casme2\CASME2_RAW_selected\CASME2_RAW_sele...
Name: 1, dtype: object

In [32]:
frames, label = dataset.__getitem__(0)

print("Input shape: ", frames.shape)
print("Image shape: ", frames[0].shape)
print("No. of frames: ", frames.shape[0])
print("Label: ", label)

Input shape:  torch.Size([3, 20, 200, 200])
Image shape:  torch.Size([20, 200, 200])
No. of frames:  3
Label:  tensor([0., 0., 1., 0., 0., 0., 0.])


In [33]:
val_size = int(dataset.__len__() / 4)
train_size = int(dataset.__len__() - val_size)

print(train_size, val_size, train_size + val_size == dataset.__len__())

192 63 True


In [34]:
train_loader, val_loader = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = torch.utils.data.DataLoader(train_loader, batch_size=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_loader, batch_size=4, shuffle=False)

In [35]:
for input, label in train_loader:
    print(input.shape, label)
    break

torch.Size([4, 3, 20, 200, 200]) tensor([[0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0.]])


In [36]:
val_loader.__len__(), train_loader.__len__()

(16, 48)

## Model Definition

In [11]:
class MainBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MainBlock, self).__init__()

        self.conv1 = torch.nn.Conv3d(in_channels, in_channels, kernel_size=(3,3,3), padding=(1,1,1))
        self.conv2 = torch.nn.Conv3d(in_channels, out_channels, kernel_size=(3,3,3), padding=(1,1,1))

        self.relu = torch.nn.ReLU()

    def forward(self,x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))

        return x

In [12]:
class SideBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SideBlock, self).__init__()

        self.conv = torch.nn.Conv3d(in_channels, out_channels, kernel_size=(3,3,3), padding=(1,1,1))
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv(x))

        return x

In [13]:
class RecallBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(RecallBlock, self).__init__()

        self.conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=(3,3,3), padding=(1,1,1))
        self.main_block = MainBlock(in_channels, out_channels)
        self.side_block = SideBlock(in_channels, out_channels)

        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv(x))

        return self.main_block(x) + self.side_block(x)

In [14]:
class SpecialResidualBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SpecialResidualBlock, self).__init__()

        self.main_block = MainBlock(in_channels, out_channels)
        self.side_block = SideBlock(in_channels, out_channels)

        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.main_block(x) + self.side_block(x)

In [15]:
class ResidualBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()

        self.main_block = MainBlock(in_channels, out_channels)

        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.main_block(x) + x

In [16]:
class MicroMotionPredictor(torch.nn.Module):
    def __init__(self, in_channels, output_shape):
        super(MicroMotionPredictor, self).__init__()

        self.conv_in = torch.nn.Conv3d(in_channels, 16, kernel_size=(3,3,3), padding=(1,1,1))
        self.recall_1 = RecallBlock(16, 32)
        self.special_res_1 = SpecialResidualBlock(32, 64)
        self.res = ResidualBlock(64, 64)

        self.conv_out = torch.nn.Conv3d(64, 128, kernel_size=(3,3,3), padding=(1,1,1))

        self.linear_1 = torch.nn.Linear(128, 64)
        self.linear_2 = torch.nn.Linear(64, output_shape)

        self.max_pool = torch.nn.MaxPool3d(kernel_size=(2,2,2)) # kernel_size=(2,2,2), stride=(2,2,2)
        self.global_pool = torch.nn.AdaptiveMaxPool3d((1,1,1))

        self.relu = torch.nn.ReLU()
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        x = self.relu(self.conv_in(x))

        x = self.max_pool(self.recall_1(x))
        # x = self.max_pool(self.recall_2(x))

        x = self.max_pool(self.special_res_1(x))
        x = self.max_pool(self.res(x))

        x = self.relu(self.conv_out(x))
        x = self.global_pool(x)

        x = x.view(-1, 128)

        x = self.relu(self.linear_1(x))
        x = self.softmax(self.linear_2(x))
        
        return x

In [17]:
class STSTNetSubBlocks(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(STSTNetSubBlocks, self).__init__()

        self.conv_in_3 = torch.nn.Conv3d(in_channels, out_channels*3, kernel_size=(3,3,3), padding=(1,1,1))
        self.conv_in_5 = torch.nn.Conv3d(in_channels, out_channels*5, kernel_size=(3,3,3), padding=(1,1,1))
        self.conv_in_8 = torch.nn.Conv3d(in_channels, out_channels*8, kernel_size=(3,3,3), padding=(1,1,1))

        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool3d(kernel_size=(2,2,2))

    def forward(self, x):
        out_3 = self.pool(self.relu(self.conv_in_3(x)))
        out_5 = self.pool(self.relu(self.conv_in_5(x)))
        out_8 = self.pool(self.relu(self.conv_in_8(x)))

        x = torch.cat([out_3, out_5, out_8], dim=1)

        return x

In [18]:
class STSTNet(torch.nn.Module):
    def __init__(self, in_channels, output_shape):
        super(STSTNet, self).__init__()

        self.block_1 = STSTNetSubBlocks(in_channels, 1)
        self.block_2 = STSTNetSubBlocks(16, 2)
        self.block_3 = STSTNetSubBlocks(32, 4)
        self.block_4 = STSTNetSubBlocks(64, 6)

        self.linear_1 = torch.nn.Linear(96*12*12, 1000)
        self.linear_2 = torch.nn.Linear(1000, 500)
        self.output_layer = torch.nn.Linear(500, output_shape)
        
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.block_1(x))
        x = self.relu(self.block_2(x))
        x = self.relu(self.block_3(x))
        x = self.relu(self.block_4(x))

        x = x.view(-1, 96*12*12)

        x = self.relu(self.linear_1(x))
        x = self.relu(self.linear_2(x))
        x = self.output_layer(x)
        
        return x

In [19]:
class STSTNet2(torch.nn.Module):
    def __init__(self, in_channels, output_shape):
        super(STSTNet2, self).__init__()

        self.conv_1 = torch.nn.Conv3d(in_channels, in_channels, kernel_size=(3,3,3), padding=(1,1,1))
        self.conv_2 = torch.nn.Conv3d(in_channels, in_channels, kernel_size=(3,3,3), padding=(1,1,1))
        self.conv_3 = torch.nn.Conv3d(in_channels, in_channels, kernel_size=(3,3,3), padding=(1,1,1))

        self.parallel_block = STSTNetSubBlocks(in_channels, 1)

        self.linear_1 = torch.nn.Linear(16*12*12, 1000)
        self.linear_2 = torch.nn.Linear(1000, 500)
        self.output_layer = torch.nn.Linear(500, output_shape)

        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool3d(kernel_size=(2,2,2))

    def forward(self, x):
        x = self.pool(self.relu(self.conv_1(x)))
        x = self.pool(self.relu(self.conv_2(x)))
        x = self.pool(self.relu(self.conv_3(x)))

        x = self.parallel_block(x)

        x = x.view(-1, 16*12*12)

        x = self.relu(self.linear_1(x))
        x = self.relu(self.linear_2(x))
        x = self.output_layer(x)

        return x

# Train Model

## Without Augmentation

In [21]:
model = STSTNet2(3, 7).to("cuda")
summary(model, (3, 20, 200, 200))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1      [-1, 3, 20, 200, 200]             246
              ReLU-2      [-1, 3, 20, 200, 200]               0
         MaxPool3d-3      [-1, 3, 10, 100, 100]               0
            Conv3d-4      [-1, 3, 10, 100, 100]             246
              ReLU-5      [-1, 3, 10, 100, 100]               0
         MaxPool3d-6         [-1, 3, 5, 50, 50]               0
            Conv3d-7         [-1, 3, 5, 50, 50]             246
              ReLU-8         [-1, 3, 5, 50, 50]               0
         MaxPool3d-9         [-1, 3, 2, 25, 25]               0
           Conv3d-10         [-1, 3, 2, 25, 25]             246
             ReLU-11         [-1, 3, 2, 25, 25]               0
        MaxPool3d-12         [-1, 3, 1, 12, 12]               0
           Conv3d-13         [-1, 5, 2, 25, 25]             410
             ReLU-14         [-1, 5, 2,

In [22]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
epochs = 10

In [23]:
for epoch in range(epochs):
    model.train()

    train_loss = 0
    train_accuracy = 0
    i = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to("cuda"), labels.to("cuda")
        optimizer.zero_grad()

        outputs = model(inputs)

        # print(outputs.shape, labels.shape)
        outputs = outputs.reshape(-1, 7)
        labels = labels.reshape(-1,7)
        loss = loss_fn(outputs, labels)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_accuracy += (predicted == torch.argmax(labels, dim=1)).sum().item()

    print(f"Epoch: {epoch+1}/{epochs}, Loss: {train_loss}, Accuracy: {train_accuracy/train_size}")
    
    model.eval()
    with torch.no_grad():
        total = 0
        correct = 0
        for inputs, labels in val_loader:
            inputs, labels = inputs.to("cuda"), labels.to("cuda")
            outputs = model(inputs)
            outputs = outputs.reshape(-1, 7)
            labels = labels.reshape(-1,7)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == torch.argmax(labels, dim=1)).sum().item()
        
        print(f"Epoch: {epoch+1}/{epochs}, Loss: {loss.item()}, Accuracy: {correct/total}")
        print()

Epoch: 1/10, Loss: 83.6700764298439, Accuracy: 0.4010416666666667
Epoch: 1/10, Loss: 2.0001907348632812, Accuracy: 0.2857142857142857

Epoch: 2/10, Loss: 73.34073668718338, Accuracy: 0.3958333333333333
Epoch: 2/10, Loss: 1.6351420879364014, Accuracy: 0.2857142857142857

Epoch: 3/10, Loss: 72.06623417139053, Accuracy: 0.421875
Epoch: 3/10, Loss: 1.885805368423462, Accuracy: 0.2857142857142857

Epoch: 4/10, Loss: 67.4593808054924, Accuracy: 0.4427083333333333
Epoch: 4/10, Loss: 1.3986167907714844, Accuracy: 0.31746031746031744

Epoch: 5/10, Loss: 67.51821452379227, Accuracy: 0.4010416666666667
Epoch: 5/10, Loss: 0.8515465259552002, Accuracy: 0.38095238095238093

Epoch: 6/10, Loss: 63.13419848680496, Accuracy: 0.4583333333333333
Epoch: 6/10, Loss: 1.5551624298095703, Accuracy: 0.3492063492063492

Epoch: 7/10, Loss: 59.17748123407364, Accuracy: 0.4791666666666667
Epoch: 7/10, Loss: 0.9111011028289795, Accuracy: 0.31746031746031744

Epoch: 8/10, Loss: 53.98774394392967, Accuracy: 0.5
Epoch:

In [24]:
torch.save(model.state_dict(), "STSTNet2_without_augmentation.pth")

In [25]:
model = STSTNet(in_channels=3, output_shape=7).to("cuda")
summary(model, (3, 20, 200, 200))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1      [-1, 3, 20, 200, 200]             246
              ReLU-2      [-1, 3, 20, 200, 200]               0
         MaxPool3d-3      [-1, 3, 10, 100, 100]               0
            Conv3d-4      [-1, 5, 20, 200, 200]             410
              ReLU-5      [-1, 5, 20, 200, 200]               0
         MaxPool3d-6      [-1, 5, 10, 100, 100]               0
            Conv3d-7      [-1, 8, 20, 200, 200]             656
              ReLU-8      [-1, 8, 20, 200, 200]               0
         MaxPool3d-9      [-1, 8, 10, 100, 100]               0
 STSTNetSubBlocks-10     [-1, 16, 10, 100, 100]               0
             ReLU-11     [-1, 16, 10, 100, 100]               0
           Conv3d-12      [-1, 6, 10, 100, 100]           2,598
             ReLU-13      [-1, 6, 10, 100, 100]               0
        MaxPool3d-14         [-1, 6, 5,

In [26]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
epochs = 10

In [27]:
for epoch in range(epochs):
    model.train()

    train_loss = 0
    train_accuracy = 0
    i = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to("cuda"), labels.to("cuda")
        optimizer.zero_grad()

        outputs = model(inputs)

        # print(outputs.shape, labels.shape)
        outputs = outputs.reshape(-1, 7)
        labels = labels.reshape(-1,7)
        loss = loss_fn(outputs, labels)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_accuracy += (predicted == torch.argmax(labels, dim=1)).sum().item()

    print(f"Epoch: {epoch+1}/{epochs}, Loss: {train_loss}, Accuracy: {train_accuracy/train_size}")
    
    model.eval()
    with torch.no_grad():
        total = 0
        correct = 0
        for inputs, labels in val_loader:
            inputs, labels = inputs.to("cuda"), labels.to("cuda")
            outputs = model(inputs)
            outputs = outputs.reshape(-1, 7)
            labels = labels.reshape(-1,7)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == torch.argmax(labels, dim=1)).sum().item()
        
        print(f"Epoch: {epoch+1}/{epochs}, Loss: {loss.item()}, Accuracy: {correct/total}")
        print()

Epoch: 1/10, Loss: 169.2735168337822, Accuracy: 0.34375
Epoch: 1/10, Loss: 1.8446950912475586, Accuracy: 0.2857142857142857

Epoch: 2/10, Loss: 82.6200498342514, Accuracy: 0.3177083333333333
Epoch: 2/10, Loss: 0.7214848399162292, Accuracy: 0.2857142857142857

Epoch: 3/10, Loss: 78.39912635087967, Accuracy: 0.421875
Epoch: 3/10, Loss: 1.3966236114501953, Accuracy: 0.2857142857142857

Epoch: 4/10, Loss: 74.31995576620102, Accuracy: 0.421875
Epoch: 4/10, Loss: 1.374719262123108, Accuracy: 0.2857142857142857

Epoch: 5/10, Loss: 73.16849058866501, Accuracy: 0.4479166666666667
Epoch: 5/10, Loss: 1.2128146886825562, Accuracy: 0.2857142857142857

Epoch: 6/10, Loss: 73.5774518251419, Accuracy: 0.40625
Epoch: 6/10, Loss: 2.4039254188537598, Accuracy: 0.38095238095238093

Epoch: 7/10, Loss: 84.36823552846909, Accuracy: 0.4583333333333333
Epoch: 7/10, Loss: 1.6961948871612549, Accuracy: 0.30158730158730157

Epoch: 8/10, Loss: 76.07870823144913, Accuracy: 0.3854166666666667
Epoch: 8/10, Loss: 1.354

In [28]:
torch.save(model.state_dict(), "STSTNet_without_augmentation.pth")

## With Augmentation

In [42]:
model = STSTNet2(3, 7).to("cuda")
summary(model, (3, 20, 200, 200))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1      [-1, 3, 20, 200, 200]             246
              ReLU-2      [-1, 3, 20, 200, 200]               0
         MaxPool3d-3      [-1, 3, 10, 100, 100]               0
            Conv3d-4      [-1, 3, 10, 100, 100]             246
              ReLU-5      [-1, 3, 10, 100, 100]               0
         MaxPool3d-6         [-1, 3, 5, 50, 50]               0
            Conv3d-7         [-1, 3, 5, 50, 50]             246
              ReLU-8         [-1, 3, 5, 50, 50]               0
         MaxPool3d-9         [-1, 3, 2, 25, 25]               0
           Conv3d-10         [-1, 3, 2, 25, 25]             246
             ReLU-11         [-1, 3, 2, 25, 25]               0
        MaxPool3d-12         [-1, 3, 1, 12, 12]               0
           Conv3d-13         [-1, 5, 2, 25, 25]             410
             ReLU-14         [-1, 5, 2,

In [43]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
epochs = 10

In [44]:
for epoch in range(epochs):
    model.train()

    train_loss = 0
    train_accuracy = 0
    i = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to("cuda"), labels.to("cuda")
        optimizer.zero_grad()

        outputs = model(inputs)

        # print(outputs.shape, labels.shape)
        outputs = outputs.reshape(-1, 7)
        labels = labels.reshape(-1,7)
        loss = loss_fn(outputs, labels)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_accuracy += (predicted == torch.argmax(labels, dim=1)).sum().item()

    print(f"Epoch: {epoch+1}/{epochs}, Loss: {train_loss}, Accuracy: {train_accuracy/train_size}")
    
    model.eval()
    with torch.no_grad():
        total = 0
        correct = 0
        for inputs, labels in val_loader:
            inputs, labels = inputs.to("cuda"), labels.to("cuda")
            outputs = model(inputs)
            outputs = outputs.reshape(-1, 7)
            labels = labels.reshape(-1,7)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == torch.argmax(labels, dim=1)).sum().item()
        
        print(f"Epoch: {epoch+1}/{epochs}, Loss: {loss.item()}, Accuracy: {correct/total}")
        print()

Epoch: 1/10, Loss: 95.16646015644073, Accuracy: 0.359375
Epoch: 1/10, Loss: 1.416420817375183, Accuracy: 0.25396825396825395

Epoch: 2/10, Loss: 77.47435414791107, Accuracy: 0.4010416666666667
Epoch: 2/10, Loss: 1.7649202346801758, Accuracy: 0.3492063492063492

Epoch: 3/10, Loss: 76.7103545665741, Accuracy: 0.4010416666666667
Epoch: 3/10, Loss: 1.8185539245605469, Accuracy: 0.3492063492063492

Epoch: 4/10, Loss: 76.31617385149002, Accuracy: 0.4114583333333333
Epoch: 4/10, Loss: 1.7310296297073364, Accuracy: 0.38095238095238093

Epoch: 5/10, Loss: 73.13253623247147, Accuracy: 0.4375
Epoch: 5/10, Loss: 1.5312503576278687, Accuracy: 0.38095238095238093

Epoch: 6/10, Loss: 71.51688235998154, Accuracy: 0.4375
Epoch: 6/10, Loss: 1.4868758916854858, Accuracy: 0.2857142857142857

Epoch: 7/10, Loss: 69.50566804409027, Accuracy: 0.4739583333333333
Epoch: 7/10, Loss: 1.6099438667297363, Accuracy: 0.31746031746031744

Epoch: 8/10, Loss: 66.64194318652153, Accuracy: 0.5052083333333334
Epoch: 8/10, 

In [45]:
torch.save(model.state_dict(), "STSTNet2_augmentation.pth")

In [46]:
model = STSTNet(in_channels=3, output_shape=7).to("cuda")
summary(model, (3, 20, 200, 200))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1      [-1, 3, 20, 200, 200]             246
              ReLU-2      [-1, 3, 20, 200, 200]               0
         MaxPool3d-3      [-1, 3, 10, 100, 100]               0
            Conv3d-4      [-1, 5, 20, 200, 200]             410
              ReLU-5      [-1, 5, 20, 200, 200]               0
         MaxPool3d-6      [-1, 5, 10, 100, 100]               0
            Conv3d-7      [-1, 8, 20, 200, 200]             656
              ReLU-8      [-1, 8, 20, 200, 200]               0
         MaxPool3d-9      [-1, 8, 10, 100, 100]               0
 STSTNetSubBlocks-10     [-1, 16, 10, 100, 100]               0
             ReLU-11     [-1, 16, 10, 100, 100]               0
           Conv3d-12      [-1, 6, 10, 100, 100]           2,598
             ReLU-13      [-1, 6, 10, 100, 100]               0
        MaxPool3d-14         [-1, 6, 5,

In [47]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
epochs = 10

In [48]:
for epoch in range(epochs):
    model.train()

    train_loss = 0
    train_accuracy = 0
    i = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to("cuda"), labels.to("cuda")
        optimizer.zero_grad()

        outputs = model(inputs)

        # print(outputs.shape, labels.shape)
        outputs = outputs.reshape(-1, 7)
        labels = labels.reshape(-1,7)
        loss = loss_fn(outputs, labels)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_accuracy += (predicted == torch.argmax(labels, dim=1)).sum().item()

    print(f"Epoch: {epoch+1}/{epochs}, Loss: {train_loss}, Accuracy: {train_accuracy/train_size}")
    
    model.eval()
    with torch.no_grad():
        total = 0
        correct = 0
        for inputs, labels in val_loader:
            inputs, labels = inputs.to("cuda"), labels.to("cuda")
            outputs = model(inputs)
            outputs = outputs.reshape(-1, 7)
            labels = labels.reshape(-1,7)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == torch.argmax(labels, dim=1)).sum().item()
        
        print(f"Epoch: {epoch+1}/{epochs}, Loss: {loss.item()}, Accuracy: {correct/total}")
        print()

Epoch: 1/10, Loss: 194.9548265337944, Accuracy: 0.3333333333333333
Epoch: 1/10, Loss: 1.1748645305633545, Accuracy: 0.3492063492063492

Epoch: 2/10, Loss: 78.5034487247467, Accuracy: 0.4010416666666667
Epoch: 2/10, Loss: 3.0566306114196777, Accuracy: 0.3492063492063492

Epoch: 3/10, Loss: 77.50202637910843, Accuracy: 0.4114583333333333
Epoch: 3/10, Loss: 1.8844460248947144, Accuracy: 0.3492063492063492

Epoch: 4/10, Loss: 76.05291545391083, Accuracy: 0.4166666666666667
Epoch: 4/10, Loss: 1.5404412746429443, Accuracy: 0.23809523809523808

Epoch: 5/10, Loss: 75.7077499628067, Accuracy: 0.3802083333333333
Epoch: 5/10, Loss: 1.554988145828247, Accuracy: 0.25396825396825395

Epoch: 6/10, Loss: 72.39123672246933, Accuracy: 0.4427083333333333
Epoch: 6/10, Loss: 1.7296024560928345, Accuracy: 0.36507936507936506

Epoch: 7/10, Loss: 73.49071782827377, Accuracy: 0.4427083333333333
Epoch: 7/10, Loss: 1.3788241147994995, Accuracy: 0.36507936507936506

Epoch: 8/10, Loss: 71.99607628583908, Accuracy:

In [49]:
torch.save(model.state_dict(), "STSTNet_augmentation.pth")