In [37]:
import torch.nn as nn

class CNNModel(nn.Module):
    def __init__(self, num_classes=2):
        super(CNNModel, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.ReLU()
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv5 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.relu5 = nn.ReLU()
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv6 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.relu6 = nn.ReLU()
        self.pool6 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.dropout = nn.Dropout(p=0.3)

        self.flatten = nn.Flatten()

        self.dense1 = nn.Linear(4608, 4096)
        self.relu7 = nn.ReLU()

        self.dense2 = nn.Linear(4096, 4096)
        self.relu8 = nn.ReLU()

        self.dense3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3(self.conv3(x)))
        x = self.pool4(self.relu4(self.conv4(x)))
        x = self.pool5(self.relu5(self.conv5(x)))
        x = self.pool6(self.relu6(self.conv6(x)))
        x = self.dropout(x)

        x = self.flatten(x)
        x = self.relu7(self.dense1(x))
        x = self.relu8(self.dense2(x))
        x = self.dense3(x)

        return x


In [38]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score, precision_score, f1_score
from torch.utils.data import DataLoader


# Set device
torch.cuda.empty_cache()
import gc
gc.collect()
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize the model and move it to the device
model = CNNModel(num_classes=7).to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load datasets
train_dataset = datasets.ImageFolder(root="Data_Set_Argho/train/", transform=transform)
test_dataset = datasets.ImageFolder(root="Data_Set_Argho/test/", transform=transform)

# Assuming you have DataLoader for training and testing data
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)



In [39]:
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    print("#" * 30)
    print(epoch)
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        print(outputs.shape, labels.shape)

        # Ensure that labels are of type torch.long
        labels = labels.long()

        # Calculate CrossEntropyLoss
        loss = criterion(outputs, labels)
        #print(loss.item())

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


##############################
0
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch

torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
t

torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
t

torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
t

torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([52, 7]) torch.Size([52])
##############################
3
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch

torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
t

torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
t

torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
t

torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([52, 7]) torch.Size([52])
##############################
6
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch

torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
t

torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
t

torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
t

torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([52, 7]) torch.Size([52])
##############################
9
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch

torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
torch.Size([64, 7]) torch.Size([64])
t

In [40]:
# Initialize empty lists for predictions and labels
all_preds = []
all_labels = []

# Validation
model.eval()
with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(test_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        _, preds = torch.max(outputs, 1)

        print(f"Batch {batch_idx + 1}/{len(test_loader)} - Output shape: {outputs.shape}, Labels shape: {labels.shape}")

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())


Batch 1/75 - Output shape: torch.Size([64, 7]), Labels shape: torch.Size([64])
Batch 2/75 - Output shape: torch.Size([64, 7]), Labels shape: torch.Size([64])
Batch 3/75 - Output shape: torch.Size([64, 7]), Labels shape: torch.Size([64])
Batch 4/75 - Output shape: torch.Size([64, 7]), Labels shape: torch.Size([64])
Batch 5/75 - Output shape: torch.Size([64, 7]), Labels shape: torch.Size([64])
Batch 6/75 - Output shape: torch.Size([64, 7]), Labels shape: torch.Size([64])
Batch 7/75 - Output shape: torch.Size([64, 7]), Labels shape: torch.Size([64])
Batch 8/75 - Output shape: torch.Size([64, 7]), Labels shape: torch.Size([64])
Batch 9/75 - Output shape: torch.Size([64, 7]), Labels shape: torch.Size([64])
Batch 10/75 - Output shape: torch.Size([64, 7]), Labels shape: torch.Size([64])
Batch 11/75 - Output shape: torch.Size([64, 7]), Labels shape: torch.Size([64])
Batch 12/75 - Output shape: torch.Size([64, 7]), Labels shape: torch.Size([64])
Batch 13/75 - Output shape: torch.Size([64, 7]), 

In [41]:
# Calculate metrics
accuracy = accuracy_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds, average='weighted')
f1 = f1_score(all_labels, all_preds, average='weighted')

print(f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, F1-Score: {f1:.4f}")

Accuracy: 0.3544, Precision: 0.3493, F1-Score: 0.3320
