# Import libraries

In [None]:
import numpy as np
import timm
import torch
import seaborn as sns
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch import nn, optim
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, f1_score
import matplotlib.pyplot as plt

# ----------------------------------------------------------

# Class

In [None]:
class Image_(Dataset):
    """Map-style dataset that pairs each sample with its label.

    Args:
        images: Indexable collection of samples.
        labels: Indexable collection of labels, aligned with ``images``.
        transform: Callable applied to a sample when it is fetched; a falsy
            value (for example ``None``) leaves samples untouched.
    """

    def __init__(self, images, labels, transform):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        """Return the number of samples, as given by ``images``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` at ``idx``, transformed if configured."""
        sample, target = self.images[idx], self.labels[idx]
        return (self.transform(sample) if self.transform else sample), target

In [None]:
# Per-sample preprocessing: convert to a PIL image, resize to 224x224, then
# convert back to a tensor.
# NOTE(review): ToPILImage only accepts certain array dtypes - confirm the
# dtype of the loaded .npy features is one of them.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# ----------------------------------------------------------

# Data

In [None]:
# Load the pre-computed feature arrays, one file per (split, class) pair.
# NOTE(review): the non-queen test file name is all lower-case while the
# train/val names use "NonQueen" - confirm it matches the file on disk.
train_queen, train_nonqueen = (
    np.load(path)
    for path in ('mfcc_queen_train.npy', 'mfcc_NonQueen_train.npy')
)

val_queen, val_nonqueen = (
    np.load(path)
    for path in ('mfcc_queen_val.npy', 'mfcc_NonQueen_val.npy')
)

test_queen, test_nonqueen = (
    np.load(path)
    for path in ('mfcc_queen_test.npy', 'mfcc_nonqueen_test.npy')
)

In [None]:
# For every split, stack the queen samples on top of the non-queen samples
# and build the matching label vector: 1.0 for queen, 0.0 for non-queen.

# Training split.
x_train = np.vstack([train_queen, train_nonqueen])
ones, zeros = np.ones(len(train_queen)), np.zeros(len(train_nonqueen))
y_train = np.concatenate([ones, zeros])

# Validation split.
x_val = np.vstack([val_queen, val_nonqueen])
ones, zeros = np.ones(len(val_queen)), np.zeros(len(val_nonqueen))
y_val = np.concatenate([ones, zeros])

# Test split.
x_test = np.vstack([test_queen, test_nonqueen])
ones, zeros = np.ones(len(test_queen)), np.zeros(len(test_nonqueen))
y_test = np.concatenate([ones, zeros])

# ----------------------------------------------------------

# Fine-tune a pretrained model

In [None]:
# Use the first CUDA device when one is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Fix every random number generator the run relies on.
seed = 42

# torch.manual_seed seeds the CPU generator as well as the CUDA ones. The CPU
# generator drives DataLoader shuffling and the initialisation of layers that
# are built before the model is moved to the GPU, so seeding only CUDA leaves
# the run non-reproducible.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

# Ask cuDNN for deterministic kernels and disable its auto-tuner, which may
# otherwise select different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Wrap each split in the dataset class; all three share one preprocessing
# pipeline.
train_set, val_set, test_set = (
    Image_(features, targets, transform=transform)
    for features, targets in ((x_train, y_train), (x_val, y_val), (x_test, y_test))
)

# Batches of 64 everywhere; only the training loader shuffles its samples.
train_dataloader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)
val_dataloader = DataLoader(dataset=val_set, batch_size=64, shuffle=False)
test_dataloader = DataLoader(dataset=test_set, batch_size=64, shuffle=False)

In [None]:
# Instantiate the 'cait_xxs36_224' architecture from timm with its pretrained weights.
model = timm.create_model(model_name='cait_xxs36_224', pretrained=True)

In [None]:
def _single_channel_conv(conv):
    """Return a 1-input-channel version of ``conv`` that reuses its weights.

    The kernels are summed over the input-channel axis, so the new layer
    responds to a single-channel input exactly as ``conv`` would respond to
    that input replicated across all of its input channels. This keeps the
    pretrained filters instead of starting from a random initialisation.

    Args:
        conv: The ``nn.Conv2d`` layer to adapt.

    Returns:
        A new ``nn.Conv2d`` with ``in_channels=1`` and the same output
        channels, kernel size, stride, padding and bias setting as ``conv``.
    """
    adapted = nn.Conv2d(
        in_channels=1,
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None,
    ).to(device=conv.weight.device, dtype=conv.weight.dtype)

    with torch.no_grad():
        adapted.weight.copy_(conv.weight.sum(dim=1, keepdim=True))
        if conv.bias is not None:
            adapted.bias.copy_(conv.bias)
    return adapted


# Kept as a module-level name for any later cell that reads it.
use_bias = model.patch_embed.proj.bias is not None

# Swap the patch-embedding projection for a single-channel one while
# preserving the pretrained filters.
model.patch_embed.proj = _single_channel_conv(model.patch_embed.proj)

# Replace the classification head with a 2-class linear layer.
model.head = nn.Linear(model.head.in_features, 2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


# Optional: resume from a previous checkpoint. Left disabled.
# pretrain_path = 'STFT_best_model_cait_turn_2.pth'
# state_dict = torch.load(pretrain_path, map_location=device)
# new_state_dict = {}
# for key, value in state_dict.items():
#     new_key = key.replace("module.", "")  # Remove 'module.' prefix if it exists
#     new_state_dict[new_key] = value

# model.load_state_dict(new_state_dict)
# print(f"Pre-trained weights loaded from {pretrain_path}")

# Spread batches over all visible GPUs. Note that wrapping the model prefixes
# every state_dict key with 'module.'.
if torch.cuda.device_count() > 1:
    print("Sử dụng", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)


In [None]:
# Cross-entropy over the model outputs, optimised with Adam at a learning
# rate of 1e-4 over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Train for up to `num_epochs` epochs, validating after each one. Stop early
# once the validation loss has failed to improve for `patience` consecutive
# epochs, and finish with the best-performing weights loaded in `model`.
num_epochs = 100
patience = 20
best_val_loss = float('inf')
early_stop_counter = 0

# state_dict() returns references to the live parameter tensors, so a real
# snapshot needs a clone; otherwise the "best" weights silently follow the
# current ones and restoring them does nothing.
best_model_wts = {
    name: tensor.detach().clone() for name, tensor in model.state_dict().items()
}

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0.0
    train_corrects = 0

    for train_images, train_labels in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):
        train_images = train_images.to(device)
        # CrossEntropyLoss needs integer class indices as targets.
        train_labels = train_labels.long().to(device)

        optimizer.zero_grad()

        outputs = model(train_images)
        loss = criterion(outputs, train_labels)

        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the batch size so that the epoch
        # average is exact even when the last batch is smaller.
        train_loss += loss.item() * train_images.size(0)
        preds = torch.argmax(outputs, dim=1)
        train_corrects += torch.sum(preds == train_labels)

    train_loss /= len(train_dataloader.dataset)
    train_acc = train_corrects.double() / len(train_dataloader.dataset)

    # ---- validation pass (no gradients) ----
    model.eval()
    val_loss = 0.0
    val_corrects = 0

    with torch.no_grad():
        for val_images, val_labels in tqdm(val_dataloader, desc=f"Epoch {epoch+1} Validation"):
            val_images = val_images.to(device)
            val_labels = val_labels.long().to(device)

            val_outputs = model(val_images)
            loss = criterion(val_outputs, val_labels)

            val_loss += loss.item() * val_images.size(0)
            preds = torch.argmax(val_outputs, dim=1)
            val_corrects += torch.sum(preds == val_labels)

    val_loss /= len(val_dataloader.dataset)
    val_acc = val_corrects.double() / len(val_dataloader.dataset)

    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # ---- checkpointing and early stopping ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'MFCC_best_model_cait_turn_1.pth')
        best_model_wts = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print("Early stopping triggered.")
            break

# Leave the best weights in the model whether training stopped early or ran
# through every epoch.
model.load_state_dict(best_model_wts)


In [None]:
# Evaluate the best checkpoint on the test set: accuracy, confusion matrix,
# macro F1 and a confusion-matrix heatmap.
best_model = model
# map_location lets a checkpoint saved on a GPU be loaded on whatever device
# is in use now, including a CPU-only machine.
best_model.load_state_dict(
    torch.load('MFCC_best_model_cait_turn_1.pth', map_location=device)
)

best_model.eval()
test_predictions = []
true_labels = []
with torch.no_grad():
    for test_images, test_labels in tqdm(test_dataloader, desc="Test Set"):
        test_images = test_images.to(device)
        # Labels are only compared on the host, so they never need to be
        # moved to the device.
        test_labels = test_labels.long()

        outputs = best_model(test_images)
        preds = torch.argmax(outputs, dim=1)

        test_predictions.extend(preds.cpu().numpy())
        true_labels.extend(test_labels.numpy())


num_correct = sum(
    int(pred == true) for pred, true in zip(test_predictions, true_labels)
)
test_accuracy = num_correct / len(test_predictions)

print(f'Test Accuracy: {test_accuracy * 100:.2f}%')


cm = confusion_matrix(true_labels, test_predictions)
print(f'Confusion Matrix:\n{cm}')


f1 = f1_score(true_labels, test_predictions, average='macro')
print(f'F1 Score (Macro): {f1:.2f}')

# One tick label per class index, shared by both axes.
tick_labels = [str(i) for i in range(len(cm))]

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=tick_labels, yticklabels=tick_labels)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.show()