In [86]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv  # Changed from DenseSAGEConv
from torch_geometric.nn import global_mean_pool  # For sparse pooling
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from Bio import PDB
import warnings
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from Bio import PDB
import pandas as pd
from itertools import groupby
from collections import Counter
import warnings
warnings.filterwarnings('ignore')
warnings.filterwarnings('ignore')

# Define data directory
data_dir = 'data/SCOP'  # Base directory for SCOP data

# Rest of the model code remains the same...

In [90]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from Bio import PDB
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

class SparseSCOPDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        # Set up the root path first
        self.root = root

        # Now we can set up the class_info_path
        self.class_info_path = os.path.join(root, 'raw/class_info.csv')

        # Dictionary to map SCOP classes to indices
        self.class_mapping = {
            'a': 0,  # All-alpha
            'b': 1,  # All-beta
            'c': 2,  # Alpha/beta
            'd': 3,  # Alpha+beta
            'e': 4,  # Multi-domain
            'f': 5,  # Membrane
            'g': 6   # Small proteins
        }

        # Dictionary to map amino acids to indices
        self.amino_acids = {
            'ALA': 0, 'ARG': 1, 'ASN': 2, 'ASP': 3, 'CYS': 4,
            'GLN': 5, 'GLU': 6, 'GLY': 7, 'HIS': 8, 'ILE': 9,
            'LEU': 10, 'LYS': 11, 'MET': 12, 'PHE': 13, 'PRO': 14,
            'SER': 15, 'THR': 16, 'TRP': 17, 'TYR': 18, 'VAL': 19,
            'UNK': 20  # Unknown amino acid
        }

        # Load class information
        if os.path.exists(self.class_info_path):
            self.class_info = pd.read_csv(self.class_info_path)
            print(f"Found class info file with {len(self.class_info)} entries")
        else:
            print(f"Warning: class_info.csv not found at {self.class_info_path}")
            self.class_info = None

        # Initialize the base class last
        super().__init__(root, transform, pre_transform, pre_filter)

    @property  # This decorator was missing!
    def raw_file_names(self):
        if not os.path.exists(os.path.join(self.root, 'raw')):
            return []
        return [f for f in os.listdir(os.path.join(self.root, 'raw'))
                if f.endswith('.pdb')]

    @property  # This decorator was missing!
    def processed_file_names(self):
        return ['data.pt']


    @property
    def raw_file_names(self):
        """List of raw file names in the dataset."""
        if not os.path.exists(os.path.join(self.root, 'raw')):
            return []
        return [f for f in os.listdir(os.path.join(self.root, 'raw'))
                if f.endswith('.pdb')]

    @property
    def processed_file_names(self):
        """List of processed file names in the dataset."""
        return ['data.pt']

    def download(self):
        """Download the dataset."""
        pass  # We already have the files

    from collections import Counter

    def process(self):
        """Process the raw data into the internal format with comprehensive logging."""
        data_list = []
        parser = PDB.PDBParser(QUIET=True)
        processed_count = 0
        skipped_count = 0
        skipped_reasons = {}

        # Log total number of entries in class_info
        print(f"Total entries in class_info: {len(self.class_info)}")

        # Verify raw directory contents
        raw_dir = os.path.join(self.root, 'raw')
        raw_files = [f for f in os.listdir(raw_dir) if f.endswith('.pdb')]
        print(f"Total PDB files in raw directory: {len(raw_files)}")

        for idx, row in self.class_info.iterrows():
            pdb_id = str(row['scop_id'])

            try:
                # Validate class mapping
                class_label = row['class']
                if class_label not in self.class_mapping:
                    skipped_count += 1
                    skipped_reasons[pdb_id] = f"Invalid class: {class_label}"
                    continue

                class_label = self.class_mapping[class_label]

                # Load structure
                pdb_file = f"{pdb_id}.pdb"
                pdb_path = os.path.join(raw_dir, pdb_file)

                # Check if file exists
                if not os.path.exists(pdb_path):
                    skipped_count += 1
                    skipped_reasons[pdb_id] = "PDB file not found"
                    continue

                structure = parser.get_structure('protein', pdb_path)
                model = structure[0]

                # Get residues and create features
                residues = list(model.get_residues())

                # Skip if too few or too many residues
                if len(residues) < 10 or len(residues) > 1400:
                    skipped_count += 1
                    skipped_reasons[pdb_id] = f"Invalid residue count: {len(residues)}"
                    continue

                # Create node features
                node_features = []
                for residue in residues:
                    features = self._get_residue_features(residue)
                    node_features.append(features)

                # Create edges with 5Å cutoff
                edges = []
                for i in range(len(residues)):
                    for j in range(i+1, len(residues)):
                        if 'CA' in residues[i] and 'CA' in residues[j]:
                            ca_i = residues[i]['CA'].get_coord()
                            ca_j = residues[j]['CA'].get_coord()
                            dist = np.linalg.norm(ca_i - ca_j)
                            if dist < 5.0:  # 5Å cutoff
                                edges.append([i, j])
                                edges.append([j, i])  # Add both directions

                if len(edges) == 0:
                    skipped_count += 1
                    skipped_reasons[pdb_id] = "No edges found"
                    continue

                # Create PyG Data object
                edge_index = torch.tensor(edges, dtype=torch.long).t()
                x = torch.tensor(node_features, dtype=torch.float)
                y = torch.tensor([class_label], dtype=torch.long)

                data = Data(
                    x=x,
                    edge_index=edge_index,
                    y=y,
                    num_nodes=len(residues)
                )

                # Additional filtering if needed
                if self.pre_filter is not None and not self.pre_filter(data):
                    skipped_count += 1
                    skipped_reasons[pdb_id] = "Failed pre-filter"
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)
                processed_count += 1

            except Exception as e:
                skipped_count += 1
                skipped_reasons[pdb_id] = f"Processing error: {str(e)}"
                continue

        # Detailed logging of skipped reasons
        print(f"\nProcessing Summary:")
        print(f"Total processed: {processed_count}")
        print(f"Total skipped: {skipped_count}")
        print("\nSkipped Reasons:")
        for reason, count in Counter(skipped_reasons.values()).most_common():
            print(f"{reason}: {count}")

        # Optional: Print some skipped PDB IDs for investigation
        print("\nSample of skipped PDB IDs:")
        for reason, pdb_ids in groupby(sorted(skipped_reasons.items(), key=lambda x: x[1]), key=lambda x: x[1]):
            print(f"{reason}: {list(pdb_ids)[:5]}")

        if len(data_list) == 0:
            raise RuntimeError("No data was successfully processed!")

        # Save processed data
        torch.save(data_list, os.path.join(self.processed_dir, 'data.pt'))

        return data_list

    def get(self, idx):
        """Get a specific graph from the dataset."""
        if not hasattr(self, '_data_list'):
            processed_path = os.path.join(self.processed_dir, 'data.pt')
            if not os.path.exists(processed_path):
                print("Warning: Processed data file not found, running processing...")
                self.process()
            # Add weights_only=False to allow loading PyG Data objects
            self._data_list = torch.load(processed_path, weights_only=False)
        return self._data_list[idx]

    def len(self):
        """Return the number of graphs in the dataset."""
        processed_path = os.path.join(self.processed_dir, 'data.pt')
        if not os.path.exists(processed_path):
            print("Warning: Processed data file not found, running processing...")
            self.process()

        if not hasattr(self, '_data_list'):
            # Add weights_only=False here as well
            self._data_list = torch.load(processed_path, weights_only=False)

        print(f"Actual number of processed samples: {len(self._data_list)}")
        return len(self._data_list)

    def _get_residue_features(self, residue):
        """Create feature vector for a residue."""
        # One-hot encode amino acid type
        aa_features = np.zeros(21)  # 20 standard amino acids + UNK
        aa_name = residue.get_resname()
        if aa_name in self.amino_acids:
            aa_features[self.amino_acids[aa_name]] = 1
        else:
            aa_features[self.amino_acids['UNK']] = 1

        # Get CA atom coordinates
        try:
            ca_atom = residue['CA']
            coords = ca_atom.get_coord()
        except:
            coords = np.zeros(3)

        # Combine features
        features = np.concatenate([
            aa_features,  # Amino acid identity (21)
            coords,      # 3D coordinates (3)
        ])

        return features

    @property
    def num_classes(self):
        """Return the number of classes in the dataset."""
        return len(self.class_mapping)

    @property
    def num_features(self):
        """Return the number of node features."""
        return 24  # 21 for amino acids + 3 for coordinates

In [89]:
import os.path as osp
import time
from math import ceil

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.loader import DenseDataLoader

NUM_CLASSES = 7  # SCOP main classes
# First, let's check protein sizes
# First, analyze protein sizes
processed_path = os.path.join(data_dir, 'processed/data.pt')
data_list = torch.load(processed_path, weights_only=False)
sizes = [data.num_nodes for data in data_list]

print(f"Protein size statistics:")
print(f"Min size: {min(sizes)}")
print(f"Max size: {max(sizes)}")
print(f"Mean size: {sum(sizes)/len(sizes):.1f}")
print(f"Median size: {sorted(sizes)[len(sizes)//2]}")
print(f"Number of proteins > 150 residues: {sum(1 for s in sizes if s > 150)}")

# Set max_nodes to 500 since we have proteins up to 1381 residues
max_nodes = 1400

# Create dataset
#dataset = SparseSCOPDataset(
#    root=data_dir,
#    pre_filter=lambda data: data.num_nodes <= max_nodes
#)

#dataset = SparseSCOPDataset(root=data_dir)
print(f"Dataset size: {len(dataset)}")

print(f"\nDataset size: {len(dataset)}")
print(f"Number of features: {dataset.num_features}")
print(f"Number of classes: {dataset.num_classes}")


FileNotFoundError: [Errno 2] No such file or directory: 'data/SCOP/processed/data.pt'

In [83]:
# Assuming your dataset is loaded
filtered_dataset = [data for data in dataset if data.num_nodes < 300]

# If you want to create a new dataset object
from torch_geometric.data import Dataset

class FilteredSCOPDataset(Dataset):
    def __init__(self, original_dataset):
        self.data_list = [data for data in original_dataset if data.num_nodes < 300]
        super().__init__(original_dataset.root)

    def len(self):
        return len(self.data_list)

    def get(self, idx):
        return self.data_list[idx]

# Create the filtered dataset
filtered_dataset = FilteredSCOPDataset(dataset)

# Verify the filtering
print(f"Original dataset size: {len(dataset)}")
print(f"Filtered dataset size: {len(filtered_dataset)}")

# Optional: Check distribution across classes
class_distribution = {}
for data in filtered_dataset:
    class_label = data.y.item()
    class_distribution[class_label] = class_distribution.get(class_label, 0) + 1

print("\nClass distribution in filtered dataset:")
for cls, count in class_distribution.items():
    print(f"Class {cls}: {count} proteins")

Original dataset size: 3420
Filtered dataset size: 2717

Class distribution in filtered dataset:
Class 1: 449 proteins
Class 2: 368 proteins
Class 0: 428 proteins
Class 6: 500 proteins
Class 5: 418 proteins
Class 3: 382 proteins
Class 4: 172 proteins


In [91]:
# Specify the root directory
data_dir = 'data/SCOP'  # Adjust this to your actual data directory path

# Create the dataset (this should trigger the process method)
dataset = SparseSCOPDataset(root=data_dir)
# Create the dataset instance

# Explicitly call the process method
processed_data = dataset.process()

Found class info file with 3500 entries
Total entries in class_info: 3500
Total PDB files in raw directory: 3424


Processing...


KeyboardInterrupt: 

In [84]:
# Create the filtered dataset
from torch_geometric.loader import DataLoader

filtered_dataset = FilteredSCOPDataset(dataset)

# Shuffle and split the filtered dataset
filtered_dataset = filtered_dataset.shuffle()
n = (len(filtered_dataset) + 9) // 10
test_dataset = filtered_dataset[:n]
val_dataset = filtered_dataset[n:2 * n]
train_dataset = filtered_dataset[2 * n:]

# Create data loaders using the filtered datasets
train_loader = DataLoader(train_dataset, batch_size=20)
val_loader = DataLoader(val_dataset, batch_size=20)
test_loader = DataLoader(test_dataset, batch_size=20)

# Print dataset sizes
print("\nFiltered Dataset splits:")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# The rest of your training loop remains the same
# You can use these loaders directly in your existing training script


Filtered Dataset splits:
Training samples: 2173
Validation samples: 272
Test samples: 272



Dataset splits:
Training samples: 2736
Validation samples: 342
Test samples: 342


In [79]:
import torch
import matplotlib.pyplot as plt
import numpy as np

# Assuming you've already loaded the dataset
# Load the processed data
processed_path = os.path.join(data_dir, 'processed/data.pt')
data_list = torch.load(processed_path, weights_only=False)

# Reverse class mapping for readable labels
class_mapping_reverse = {
    0: 'a (All-alpha)',
    1: 'b (All-beta)',
    2: 'c (Alpha/beta)',
    3: 'd (Alpha+beta)',
    4: 'e (Multi-domain)',
    5: 'f (Membrane)',
    6: 'g (Small proteins)'
}

# Separate nodes by class
nodes_by_class = {}
for data in data_list:
    class_label = data.y.item()
    if class_label not in nodes_by_class:
        nodes_by_class[class_label] = []
    nodes_by_class[class_label].append(data.num_nodes)

# Create the histogram
plt.figure(figsize=(12, 6))

# Box plot
plt.boxplot([nodes_by_class[key] for key in sorted(nodes_by_class.keys())],
            labels=[class_mapping_reverse[key] for key in sorted(nodes_by_class.keys())])

plt.title('Number of Nodes per SCOP Class', fontsize=16)
plt.xlabel('SCOP Class', fontsize=12)
plt.ylabel('Number of Nodes (Residues)', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.tight_layout()

# Save the plot
plt.savefig('nodes_per_class_boxplot.png')
plt.close()

# Print some statistics
print("Node count statistics per class:")
for class_label, nodes in nodes_by_class.items():
    print(f"\n{class_mapping_reverse[class_label]}:")
    print(f"  Count: {len(nodes)}")
    print(f"  Min nodes: {min(nodes)}")
    print(f"  Max nodes: {max(nodes)}")
    print(f"  Mean nodes: {np.mean(nodes):.2f}")
    print(f"  Median nodes: {np.median(nodes):.2f}")

Node count statistics per class:

a (All-alpha):
  Count: 500
  Min nodes: 39
  Max nodes: 656
  Mean nodes: 172.21
  Median nodes: 142.50

b (All-beta):
  Count: 500
  Min nodes: 53
  Max nodes: 548
  Mean nodes: 170.79
  Median nodes: 135.00

c (Alpha/beta):
  Count: 500
  Min nodes: 44
  Max nodes: 815
  Mean nodes: 255.96
  Median nodes: 247.00

d (Alpha+beta):
  Count: 424
  Min nodes: 56
  Max nodes: 640
  Mean nodes: 183.87
  Median nodes: 164.00

e (Multi-domain):
  Count: 496
  Min nodes: 66
  Max nodes: 1395
  Mean nodes: 387.85
  Median nodes: 358.00

f (Membrane):
  Count: 500
  Min nodes: 25
  Max nodes: 746
  Mean nodes: 184.99
  Median nodes: 146.00

g (Small proteins):
  Count: 500
  Min nodes: 20
  Max nodes: 157
  Mean nodes: 62.37
  Median nodes: 56.00


In [80]:
processed_path = os.path.join(data_dir, 'processed/data.pt')
data_list = torch.load(processed_path, weights_only=False)

# Count proteins per class and above 300 nodes
class_counts = {}
above_300_counts = {}

for data in data_list:
    class_label = data.y.item()
    if class_label not in class_counts:
        class_counts[class_label] = 0
        above_300_counts[class_label] = 0

    class_counts[class_label] += 1
    if data.num_nodes > 300:
        above_300_counts[class_label] += 1

print("Total proteins per class:")
for cls, count in class_counts.items():
    print(f"Class {cls}: {count} total, {above_300_counts[cls]} above 300 nodes ({above_300_counts[cls]/count*100:.2f}%)")

Total proteins per class:
Class 0: 500 total, 71 above 300 nodes (14.20%)
Class 1: 500 total, 50 above 300 nodes (10.00%)
Class 2: 500 total, 131 above 300 nodes (26.20%)
Class 3: 424 total, 42 above 300 nodes (9.91%)
Class 4: 496 total, 324 above 300 nodes (65.32%)
Class 5: 500 total, 81 above 300 nodes (16.20%)
Class 6: 500 total, 0 above 300 nodes (0.00%)


In [77]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool
from math import ceil

class SparseGNN(torch.nn.Module):
    def __init__(self, num_features, hidden_dim=64, num_classes=7):
        super().__init__()

        # GNN layers with sparse representation
        self.conv1 = SAGEConv(num_features, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.conv3 = SAGEConv(hidden_dim, hidden_dim)

        # Batch normalization layers
        self.bn1 = torch.nn.BatchNorm1d(hidden_dim)
        self.bn2 = torch.nn.BatchNorm1d(hidden_dim)
        self.bn3 = torch.nn.BatchNorm1d(hidden_dim)

        # Final classification layers
        self.lin1 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, num_classes)

        # Dropout
        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, x, edge_index, batch):
        # First convolution layer
        h1 = self.conv1(x, edge_index)
        h1 = self.bn1(h1)
        h1 = F.relu(h1)
        h1 = self.dropout(h1)

        # Second convolution layer
        h2 = self.conv2(h1, edge_index)
        h2 = self.bn2(h2)
        h2 = F.relu(h2)
        h2 = self.dropout(h2)

        # Third convolution layer
        h3 = self.conv3(h2, edge_index)
        h3 = self.bn3(h3)
        h3 = F.relu(h3)
        h3 = self.dropout(h3)

        # Global mean pooling
        out = global_mean_pool(h3, batch)

        # MLP head
        out = self.lin1(out)
        out = F.relu(out)
        out = self.dropout(out)
        out = self.lin2(out)

        return out


class Net(torch.nn.Module):
    def __init__(self, dataset):
        super().__init__()

        # Create a single GNN model for classification
        self.gnn = SparseGNN(
            num_features=dataset.num_features,
            hidden_dim=64,
            num_classes=dataset.num_classes
        )

    def forward(self, x, edge_index, batch):
        # Single forward pass through the GNN
        x = self.gnn(x, edge_index, batch)
        return F.log_softmax(x, dim=-1)


def setup_device():
    """Determine the best available device."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')


def train(model, train_loader, optimizer, device):
    """Training function for the model."""
    model.train()
    total_loss = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(output, data.y.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * data.num_graphs

    return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test(model, loader, device):
    """Evaluation function for the model."""
    model.eval()
    correct = 0
    total = 0

    for data in loader:
        data = data.to(device)
        pred = model(data.x, data.edge_index, data.batch).max(dim=1)[1]
        correct += int(pred.eq(data.y.view(-1)).sum())
        total += data.num_graphs

    return correct / total


def main(dataset, train_loader, val_loader, test_loader):
    """Main training and evaluation loop."""
    device = setup_device()
    model = Net(dataset).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

    best_val_acc = 0
    test_acc = 0

    for epoch in range(1, 151):
        train_loss = train(model, train_loader, optimizer, device)
        val_acc = test(model, val_loader, device)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = test(model, test_loader, device)

        print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, '
              f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

    return model, best_val_acc, test_acc

# Note: This function would be called after setting up the dataset, loaders, etc.
# main(dataset, train_loader, val_loader, test_loader)

In [78]:
# Prepare data loaders (as you had before)
dataset = dataset.shuffle()
n = (len(dataset) + 9) // 10
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]

# Use DataLoader for sparse graphs
train_loader = DataLoader(train_dataset, batch_size=20)
val_loader = DataLoader(val_dataset, batch_size=20)
test_loader = DataLoader(test_dataset, batch_size=20)

# Train the model
model, best_val_acc, test_acc = main(dataset, train_loader, val_loader, test_loader)

Epoch: 001, Train Loss: 1.8976, Val Acc: 0.1930, Test Acc: 0.2515
Epoch: 002, Train Loss: 1.8763, Val Acc: 0.2281, Test Acc: 0.2485
Epoch: 003, Train Loss: 1.8571, Val Acc: 0.2164, Test Acc: 0.2485
Epoch: 004, Train Loss: 1.8424, Val Acc: 0.2047, Test Acc: 0.2485
Epoch: 005, Train Loss: 1.8294, Val Acc: 0.2193, Test Acc: 0.2485
Epoch: 006, Train Loss: 1.8179, Val Acc: 0.2105, Test Acc: 0.2485
Epoch: 007, Train Loss: 1.8021, Val Acc: 0.2368, Test Acc: 0.2719
Epoch: 008, Train Loss: 1.7708, Val Acc: 0.2719, Test Acc: 0.3099
Epoch: 009, Train Loss: 1.6554, Val Acc: 0.3392, Test Acc: 0.3450
Epoch: 010, Train Loss: 1.5773, Val Acc: 0.3187, Test Acc: 0.3450
Epoch: 011, Train Loss: 1.5352, Val Acc: 0.3480, Test Acc: 0.3392
Epoch: 012, Train Loss: 1.5381, Val Acc: 0.3070, Test Acc: 0.3392
Epoch: 013, Train Loss: 1.5107, Val Acc: 0.3509, Test Acc: 0.3567
Epoch: 014, Train Loss: 1.4955, Val Acc: 0.3363, Test Acc: 0.3567
Epoch: 015, Train Loss: 1.4911, Val Acc: 0.3275, Test Acc: 0.3567
Epoch: 016

In [85]:
from torch_geometric.loader import DataLoader


# Prepare data loaders
dataset = dataset.shuffle()
n = (len(dataset) + 9) // 10
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]

# Replace your existing DataLoader with this
train_loader = DataLoader(train_dataset,  batch_size=20)
val_loader = DataLoader(val_dataset,  batch_size=20)
test_loader = DataLoader(test_dataset,  batch_size=20)

print("\nDataset splits:")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

model, best_val_acc, test_acc = main(dataset, train_loader, val_loader, test_loader)



Dataset splits:
Training samples: 2736
Validation samples: 342
Test samples: 342
Epoch: 001, Train Loss: 1.9036, Val Acc: 0.2135, Test Acc: 0.2164
Epoch: 002, Train Loss: 1.8695, Val Acc: 0.2047, Test Acc: 0.2164
Epoch: 003, Train Loss: 1.8478, Val Acc: 0.2251, Test Acc: 0.2485
Epoch: 004, Train Loss: 1.8373, Val Acc: 0.2515, Test Acc: 0.2251
Epoch: 005, Train Loss: 1.8174, Val Acc: 0.2368, Test Acc: 0.2251
Epoch: 006, Train Loss: 1.7879, Val Acc: 0.2924, Test Acc: 0.2749
Epoch: 007, Train Loss: 1.7467, Val Acc: 0.3304, Test Acc: 0.3041
Epoch: 008, Train Loss: 1.6779, Val Acc: 0.3421, Test Acc: 0.3333
Epoch: 009, Train Loss: 1.6030, Val Acc: 0.3509, Test Acc: 0.3304
Epoch: 010, Train Loss: 1.5866, Val Acc: 0.3509, Test Acc: 0.3304
Epoch: 011, Train Loss: 1.5378, Val Acc: 0.3626, Test Acc: 0.3626
Epoch: 012, Train Loss: 1.5079, Val Acc: 0.4035, Test Acc: 0.4006
Epoch: 013, Train Loss: 1.4732, Val Acc: 0.3801, Test Acc: 0.4006
Epoch: 014, Train Loss: 1.4386, Val Acc: 0.4006, Test Acc: 0

KeyboardInterrupt: 