# Use H97_1__0_1_2 vs:

- H97_1__0_1

- H97_1__0_2

- H97_1__1_2

# Training data for the 3 models:

- H97_1__0_1

- H97_1__0_2

- H97_1__1_2

Data layout:
```markdown
train
    |___B2
        |___record1: [xs1, xs2, xs3]
        |___record2: [xs1, xs2, xs3]
    |___B5
    |___B6
valid
    |___B2
    |___B5
    |___B6
```

# Download code and setup

In [None]:
!git clone https://github.com/Harito97/AI_Product_ThyroidCancerClassifier.git
%cd AI_Product_ThyroidCancerClassifier
!git pull


In [None]:
%pip install wandb

In [None]:
import wandb

wandb.login()

# Create the training data for the models:

```markdown
train
    |___B2: [[xs1, xs2, xs3], [xs1, xs2, xs3], ...]
    |___B5: [[xs1, xs2, xs3], [xs1, xs2, xs3], ...]
    |___B6: [[xs1, xs2, xs3], [xs1, xs2, xs3], ...]
valid
    |___B2: [[xs1, xs2, xs3], [xs1, xs2, xs3], ...]
    |___B5: [[xs1, xs2, xs3], [xs1, xs2, xs3], ...]
    |___B6: [[xs1, xs2, xs3], [xs1, xs2, xs3], ...]
```

## Load dataver1

In [None]:
import os
import pandas as pd

def count_images_per_class(folder_path):
    """Count the regular files in each class subfolder of ``folder_path``.

    Only the subfolders ``B2``, ``B5`` and ``B6`` are inspected; one that is
    not an existing directory is left out of the result. Every regular file
    is counted, whatever its extension.

    Returns:
        dict mapping subfolder name to its file count.
    """
    counts = {}
    for class_name in ("B2", "B5", "B6"):
        class_dir = os.path.join(folder_path, class_name)
        if not os.path.isdir(class_dir):
            continue
        counts[class_name] = sum(
            1
            for entry in os.listdir(class_dir)
            if os.path.isfile(os.path.join(class_dir, entry))
        )
    return counts

# Paths to the data folders
train_folder = '/dataver1/train'
valid_folder = '/dataver1/valid'

# Count the images for each label
train_counts = count_images_per_class(train_folder)
valid_counts = count_images_per_class(valid_folder)

# Convert the counts to DataFrames and print the results
df_train = pd.DataFrame(list(train_counts.items()), columns=['Class', 'Number of Images'])
df_valid = pd.DataFrame(list(valid_counts.items()), columns=['Class', 'Number of Images'])

print("Train Data")
print(df_train)
print("Validation Data")
print(df_valid)

# Weight of each class
import numpy as np

# Compute the weights for the loss function:
# weight = total images / (number of classes * images in this class),
# so classes with fewer images get a larger weight.
total_train_images = sum(train_counts.values())
class_weights = {
    class_name: total_train_images / (len(train_counts) * num_images)
    for class_name, num_images in train_counts.items()
}

# Print the weights
print("Class Weights:", class_weights)

In [None]:
import random

import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset

# NOTE(review): `transforms` used below is presumably torchvision.transforms;
# no import for it is visible in this notebook - confirm and add it.


class CustomDataset(Dataset):
    """Image-folder dataset over the ``B2``/``B5``/``B6`` subfolders.

    A subfolder's position in that list is its integer label
    (B2 -> 0, B5 -> 1, B6 -> 2). Only ``.png``/``.jpg``/``.jpeg`` files
    (case-insensitive) are indexed, in ``os.listdir`` order.
    """

    def __init__(self, folder_path, transform=None):
        """Index every image under ``folder_path``.

        Args:
            folder_path: directory that contains the three class subfolders.
            transform: optional callable applied to each loaded PIL image.
        """
        self.folder_path = folder_path
        self.transform = transform
        self.image_paths = []
        self.labels = []
        image_suffixes = (".png", ".jpg", ".jpeg")
        for class_index, class_name in enumerate(["B2", "B5", "B6"]):
            class_dir = os.path.join(folder_path, class_name)
            for file_name in os.listdir(class_dir):
                if not file_name.lower().endswith(image_suffixes):
                    continue
                self.image_paths.append(os.path.join(class_dir, file_name))
                self.labels.append(class_index)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        The image is opened and converted to RGB; ``transform`` is applied
        when one was supplied.
        """
        image = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]


# Data augmentation (defined here, but not applied to the datasets built below)
transform_train = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(90),
        transforms.ToTensor(),
    ]
)

# Non-augmenting preprocessing: resize and convert to a tensor only
transform_valid = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# train_dataset = CustomDataset(train_folder, transform=transform_train)
# The training set is built with the non-augmenting transform.
train_dataset = CustomDataset(train_folder, transform=transform_valid)
valid_dataset = CustomDataset(valid_folder, transform=transform_valid)

# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Neither loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)

## Get the model H97_1__0_1_2

In [None]:
%cd /path/to/AI_Product_ThyroidCancerClassifier

In [None]:
from model.H97_1__0_1_2 import H97_1__0_1_2

# Build the three-class model and restore its trained weights.
h97_1__0_1_2 = H97_1__0_1_2()
h97_1__0_1_2.load_state_dict(
    torch.load("output_ex/best_model_CNN_B2_B5_B6_dataver1_trainx1_30epoch.pth")
)
# evaluate_model sends every batch to the GPU, so the model must live there
# too (the pairwise models below are moved with .cuda() in the same way).
h97_1__0_1_2.cuda()

In [None]:
%matplotlib inline
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay, auc, roc_auc_score, roc_curve
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import numpy as np
import torch

def evaluate_model(
    model,
    valid_loader,
    checkpoint_path='output_ex/best_model_CNN_B2_B5_B6_dataver1_trainx1_30epoch.pth',
):
    """Evaluate the three-class model and collect its raw outputs per batch.

    Loads ``checkpoint_path`` into ``model``, runs it over ``valid_loader``
    on the GPU, logs accuracy and weighted F1 to the active WandB run, prints
    them, and shows a confusion matrix plus one ROC curve per class.

    Args:
        model: network to evaluate; its weights are overwritten from
            ``checkpoint_path``.
        valid_loader: iterable of ``(images, labels)`` batches.
        checkpoint_path: ``state_dict`` file to load before evaluating.

    Returns:
        dict with keys ``"x"`` (list of per-batch output tensors) and ``"y"``
        (list of per-batch label tensors), both kept on the CPU.
    """
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()
    all_labels, all_preds, all_preds_prob = [], [], []

    # String keys: the bare names x / y are not defined anywhere.
    output_of_H97_1__0_1_2 = {"x": [], "y": []}

    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)

            # Store CPU copies so the collected outputs do not hold GPU memory.
            output_of_H97_1__0_1_2["x"].append(outputs.cpu())
            output_of_H97_1__0_1_2["y"].append(labels.cpu())

            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            preds_prob = torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            all_preds_prob.extend(preds_prob)

    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted')
    # Pin the label set so the matrix always matches the three display labels.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1, 2])

    # Log metrics to WandB
    wandb.log({'accuracy': accuracy, 'f1_score': f1})

    print(f'Accuracy: {accuracy:.4f}')
    print(f'F1 Score: {f1:.4f}')

    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['B2', 'B5', 'B6'])
    disp.plot(cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.show()

    # ROC AUC: one-vs-rest curve per class from the softmax probabilities
    all_labels_bin = label_binarize(all_labels, classes=[0, 1, 2])
    all_preds_prob = np.array(all_preds_prob)

    plt.figure(figsize=(10, 8))
    for i in range(all_labels_bin.shape[1]):
        fpr, tpr, _ = roc_curve(all_labels_bin[:, i], all_preds_prob[:, i])
        plt.plot(fpr, tpr, label=f'Class {i} (area = {auc(fpr, tpr):.2f})')

    plt.title('ROC Curve')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend(loc='lower right')
    plt.show()

    return output_of_H97_1__0_1_2

In [None]:
# Evaluate the three-class model on the training set and keep its outputs.
# The model variable is h97_1__0_1_2; no name `model` is defined at this point.
wandb.init(project="AI_Product_ThyroidCancerClassifier", name='CNN_B2_B5_B6_dataver1_test_on_train')
output_of_H97_1__0_1_2_on_trainset = evaluate_model(h97_1__0_1_2, train_loader)
wandb.finish()

In [None]:
# Evaluate the three-class model on the validation set and keep its outputs.
# The run gets its own name so it is not confused with the train-set run.
wandb.init(project="AI_Product_ThyroidCancerClassifier", name='CNN_B2_B5_B6_dataver1_test_on_valid')
output_of_H97_1__0_1_2_on_validset = evaluate_model(h97_1__0_1_2, valid_loader)
wandb.finish()

In [None]:
# Print the collected output dicts returned by evaluate_model for both splits
print('Output of H97_1__0_1_2 on train set:', output_of_H97_1__0_1_2_on_trainset)
print('Output of H97_1__0_1_2 on valid set:', output_of_H97_1__0_1_2_on_validset)

In [None]:
# pickle is not imported anywhere else in the notebook
import pickle

# Save the train set output
with open('output_ex/output_of_H97_1__0_1_2_on_trainset.pkl', 'wb') as f:
    pickle.dump(output_of_H97_1__0_1_2_on_trainset, f)

# Save the validation set output
with open('output_ex/output_of_H97_1__0_1_2_on_validset.pkl', 'wb') as f:
    pickle.dump(output_of_H97_1__0_1_2_on_validset, f)

# To read the saved outputs later
# NOTE: pickle.load can execute arbitrary code; only load files written by
# this notebook, never files from an untrusted source.
with open('output_ex/output_of_H97_1__0_1_2_on_trainset.pkl', 'rb') as f:
    loaded_trainset_output = pickle.load(f)

with open('output_ex/output_of_H97_1__0_1_2_on_validset.pkl', 'rb') as f:
    loaded_validset_output = pickle.load(f)

# Print to verify
print(loaded_trainset_output)
print(loaded_validset_output)

# Train H97_1__i_j

## Train H97_1__0_1

In [None]:
import torch
import numpy as np

def filter_data(output_dict, labels_to_include):
    """Keep only the samples whose label is in ``labels_to_include``.

    Args:
        output_dict: dict with ``"x"`` (list of per-batch output tensors) and
            ``"y"`` (list of per-batch label tensors).
        labels_to_include: integer labels to keep.

    Returns:
        dict of the same form. Batches stay aligned; a batch with no matching
        sample becomes an empty tensor. Label values are not renumbered.
    """
    # String keys: the bare names x / y are not defined anywhere.
    filtered_outputs = {"x": [], "y": []}
    for outputs, labels in zip(output_dict["x"], output_dict["y"]):
        # Boolean mask over the batch, converted to a tensor for indexing
        mask = torch.from_numpy(np.isin(labels.cpu().numpy(), labels_to_include))
        filtered_outputs["x"].append(outputs[mask])
        filtered_outputs["y"].append(labels[mask])
    return filtered_outputs

# Filter to include only labels 0 and 1
# (CustomDataset assigns B2 -> 0, B5 -> 1, B6 -> 2, so this keeps B2 and B5)
labels_to_include = [0, 1]
filtered_trainset_output = filter_data(output_of_H97_1__0_1_2_on_trainset, labels_to_include)
filtered_validset_output = filter_data(output_of_H97_1__0_1_2_on_validset, labels_to_include)

In [None]:
from torch.utils.data import DataLoader, TensorDataset

def prepare_dataloader(filtered_output, batch_size=32, shuffle=True):
    """Wrap filtered outputs and labels in a ``DataLoader``.

    Args:
        filtered_output: dict with ``"x"`` and ``"y"`` lists of per-batch
            tensors; each list is concatenated along the first dimension.
        batch_size: batch size of the returned loader.
        shuffle: whether the loader reshuffles every epoch (default ``True``,
            the previous fixed behaviour).

    Returns:
        DataLoader over a ``TensorDataset`` of ``(outputs, labels)`` pairs.
    """
    # String keys: the bare names x / y are not defined anywhere.
    all_outputs = torch.cat(filtered_output["x"])
    all_labels = torch.cat(filtered_output["y"])
    dataset = TensorDataset(all_outputs, all_labels)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# Wrap the filtered outputs/labels of each split in a DataLoader
train_loader_filtered = prepare_dataloader(filtered_trainset_output)
valid_loader_filtered = prepare_dataloader(filtered_validset_output)

In [None]:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
import torch

# Function to extract labels from DataLoader
def extract_labels_from_dataloader(dataloader):
    """Gather every label yielded by ``dataloader`` into one NumPy array.

    The loader is iterated once; the first element of each batch is ignored.
    """
    collected = [
        label
        for _, batch_labels in dataloader
        for label in batch_labels.cpu().numpy()
    ]
    return np.array(collected)

# Extract labels from train_loader_filtered
train_labels = extract_labels_from_dataloader(train_loader_filtered)

# Calculate class weights ('balanced' mode, over the labels present in the loader)
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
# Keys are positions 0..k-1 in sorted-label order, not the label values
# themselves (for the 0/1 pair the two coincide).
class_weights_dict = {i: weight for i, weight in enumerate(class_weights)}

print("Class Weights:", class_weights_dict)
# Rebind the global name that train() reads when it builds the loss
class_weights = class_weights_dict

# Convert class weights to a tensor
# class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).cuda()

In [None]:
from torch.optim import Adam
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.preprocessing import label_binarize

# Initialize WandB (the run is closed by the wandb.finish() call at the end of train())
wandb.init(
    project="AI_Product_ThyroidCancerClassifier",
    name="h97_1__0_1_dataver1",
)


def train(
    model,
    train_loader,
    valid_loader,
    num_epochs=100,
    patience=15,
    checkpoint_path="output_ex/best_model_h97_1__0_1_dataver1.pth",
):
    """Train ``model`` with class-weighted cross-entropy and early stopping.

    Reads the module-level ``class_weights`` dict for the loss weights (in
    the dict's insertion order), moves every batch to the GPU, logs per-epoch
    metrics to the active WandB run and closes that run before returning.

    Args:
        model: network to train; expected to already be on the GPU.
        train_loader: yields ``(inputs, labels)`` training batches.
        valid_loader: yields ``(inputs, labels)`` validation batches.
        num_epochs: maximum number of epochs.
        patience: number of consecutive epochs without a new best validation
            loss after which training stops.
        checkpoint_path: file that receives the ``state_dict`` with the
            lowest validation loss.
    """
    # torch.nn is spelled out because no `nn` import is visible in the notebook
    criterion = torch.nn.CrossEntropyLoss(
        weight=torch.tensor(list(class_weights.values()), dtype=torch.float32).cuda()
    )
    optimizer = Adam(model.parameters(), lr=1e-4)
    best_loss = float("inf")
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.cuda(), labels.cuda()
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)

        # Validation
        model.eval()
        valid_loss = 0.0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.cuda(), labels.cuda()
                outputs = model(images)
                loss = criterion(outputs, labels)
                valid_loss += loss.item()
                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        epoch_valid_loss = valid_loss / len(valid_loader)

        # Compute and log the metrics before the early-stopping check, so the
        # epoch that triggers the stop is recorded as well.
        accuracy = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average="weighted")
        wandb.log(
            {
                "epoch": epoch,
                "accuracy": accuracy,
                "f1_score": f1,
                "train_loss": epoch_loss,
                "valid_loss": epoch_valid_loss,
            }
        )

        # Early stopping: checkpoint on improvement, otherwise count towards patience
        if epoch_valid_loss < best_loss:
            best_loss = epoch_valid_loss
            torch.save(model.state_dict(), checkpoint_path)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping")
                break

    wandb.finish()


# Train model
from model.H97_1__i_j import H97_1__i_j

# Pairwise model for classes 0 and 1, moved to the GPU before training
h97_1__0_1 = H97_1__i_j()
h97_1__0_1.cuda()
train(h97_1__0_1, train_loader_filtered, valid_loader_filtered)

In [None]:
%matplotlib inline
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay, auc, roc_auc_score, roc_curve
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import numpy as np

def evaluate_model(
    model,
    valid_loader,
    checkpoint_path='output_ex/best_model_h97_1__0_1_dataver1.pth',
):
    """Evaluate a pairwise model on ``valid_loader`` and plot the results.

    Loads ``checkpoint_path`` into ``model``, runs it on the GPU, logs
    accuracy and weighted F1 to the active WandB run, prints them, and shows
    a confusion matrix and ROC curves.

    Args:
        model: network to evaluate; its weights are overwritten from
            ``checkpoint_path``.
        valid_loader: iterable of ``(inputs, labels)`` batches.
        checkpoint_path: ``state_dict`` file to load before evaluating.
    """
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()
    all_labels = []
    all_preds = []
    all_preds_prob = []

    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            preds_prob = torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            all_preds_prob.extend(preds_prob)

    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted')
    # Pin the label set: the filtered data holds only two of the three
    # classes, and the matrix must still match the three display labels.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1, 2])

    # Log metrics to WandB
    wandb.log({'accuracy': accuracy, 'f1_score': f1})

    print(f'Accuracy: {accuracy:.4f}')
    print(f'F1 Score: {f1:.4f}')

    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['B2', 'B5', 'B6'])
    disp.plot(cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.show()

    # ROC AUC
    all_labels_bin = label_binarize(all_labels, classes=[0, 1, 2])
    all_preds_prob = np.array(all_preds_prob)

    plt.figure(figsize=(10, 8))
    # Draw a curve only for classes that have a probability column and both
    # positive and negative samples; otherwise the curve cannot be computed.
    # NOTE(review): column i is taken to be class i - confirm against the
    # output layout of H97_1__i_j.
    n_prob_columns = all_preds_prob.shape[1] if all_preds_prob.ndim == 2 else 0
    for i in range(min(all_labels_bin.shape[1], n_prob_columns)):
        n_positive = int(all_labels_bin[:, i].sum())
        if n_positive == 0 or n_positive == len(all_labels_bin):
            continue
        fpr, tpr, _ = roc_curve(all_labels_bin[:, i], all_preds_prob[:, i])
        plt.plot(fpr, tpr, label=f'Class {i} (area = {auc(fpr, tpr):.2f})')

    plt.title('ROC Curve')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend(loc='lower right')
    plt.show()

# Evaluate the 0-vs-1 model on its filtered validation loader in a separate WandB run
wandb.init(project="AI_Product_ThyroidCancerClassifier", name='h97_1__0_1_dataver1_test')
evaluate_model(h97_1__0_1, valid_loader_filtered)
wandb.finish()

## Train H97_1__1_2

In [None]:
import torch
import numpy as np

def filter_data(output_dict, labels_to_include):
    """Keep only the samples whose label is in ``labels_to_include``.

    Args:
        output_dict: dict with ``"x"`` (list of per-batch output tensors) and
            ``"y"`` (list of per-batch label tensors).
        labels_to_include: integer labels to keep.

    Returns:
        dict of the same form. Batches stay aligned; a batch with no matching
        sample becomes an empty tensor. Label values are not renumbered.
    """
    # String keys: the bare names x / y are not defined anywhere.
    filtered_outputs = {"x": [], "y": []}
    for outputs, labels in zip(output_dict["x"], output_dict["y"]):
        # Boolean mask over the batch, converted to a tensor for indexing
        mask = torch.from_numpy(np.isin(labels.cpu().numpy(), labels_to_include))
        filtered_outputs["x"].append(outputs[mask])
        filtered_outputs["y"].append(labels[mask])
    return filtered_outputs

# Filter to include only labels 1 and 2
# (CustomDataset assigns B2 -> 0, B5 -> 1, B6 -> 2, so this keeps B5 and B6)
labels_to_include = [1, 2]
filtered_trainset_output = filter_data(output_of_H97_1__0_1_2_on_trainset, labels_to_include)
filtered_validset_output = filter_data(output_of_H97_1__0_1_2_on_validset, labels_to_include)

In [None]:
from torch.utils.data import DataLoader, TensorDataset

def prepare_dataloader(filtered_output, batch_size=32, shuffle=True):
    """Wrap filtered outputs and labels in a ``DataLoader``.

    Args:
        filtered_output: dict with ``"x"`` and ``"y"`` lists of per-batch
            tensors; each list is concatenated along the first dimension.
        batch_size: batch size of the returned loader.
        shuffle: whether the loader reshuffles every epoch (default ``True``,
            the previous fixed behaviour).

    Returns:
        DataLoader over a ``TensorDataset`` of ``(outputs, labels)`` pairs.
    """
    # String keys: the bare names x / y are not defined anywhere.
    all_outputs = torch.cat(filtered_output["x"])
    all_labels = torch.cat(filtered_output["y"])
    dataset = TensorDataset(all_outputs, all_labels)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# Wrap the filtered outputs/labels of each split in a DataLoader
train_loader_filtered = prepare_dataloader(filtered_trainset_output)
valid_loader_filtered = prepare_dataloader(filtered_validset_output)

In [None]:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
import torch

# Function to extract labels from DataLoader
def extract_labels_from_dataloader(dataloader):
    """Gather every label yielded by ``dataloader`` into one NumPy array.

    The loader is iterated once; the first element of each batch is ignored.
    """
    collected = [
        label
        for _, batch_labels in dataloader
        for label in batch_labels.cpu().numpy()
    ]
    return np.array(collected)

# Extract labels from train_loader_filtered
train_labels = extract_labels_from_dataloader(train_loader_filtered)

# Calculate class weights ('balanced' mode, over the labels present in the loader)
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
# Keys are positions 0..k-1 in sorted-label order, not the label values.
# NOTE(review): the labels here are 1 and 2 but the weights are keyed 0 and 1,
# and the labels themselves are not renumbered - confirm this matches the
# number of outputs of H97_1__i_j and what the loss expects.
class_weights_dict = {i: weight for i, weight in enumerate(class_weights)}

print("Class Weights:", class_weights_dict)
# Rebind the global name that train() reads when it builds the loss
class_weights = class_weights_dict

# Convert class weights to a tensor
# class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).cuda()

In [None]:
from torch.optim import Adam
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.preprocessing import label_binarize

# Initialize WandB (the run is closed by the wandb.finish() call at the end of train())
wandb.init(
    project="AI_Product_ThyroidCancerClassifier",
    name="h97_1__1_2_dataver1",
)


def train(
    model,
    train_loader,
    valid_loader,
    num_epochs=100,
    patience=15,
    checkpoint_path="output_ex/best_model_h97_1__1_2_dataver1.pth",
):
    """Train ``model`` with class-weighted cross-entropy and early stopping.

    Reads the module-level ``class_weights`` dict for the loss weights (in
    the dict's insertion order), moves every batch to the GPU, logs per-epoch
    metrics to the active WandB run and closes that run before returning.

    Args:
        model: network to train; expected to already be on the GPU.
        train_loader: yields ``(inputs, labels)`` training batches.
        valid_loader: yields ``(inputs, labels)`` validation batches.
        num_epochs: maximum number of epochs.
        patience: number of consecutive epochs without a new best validation
            loss after which training stops.
        checkpoint_path: file that receives the ``state_dict`` with the
            lowest validation loss.
    """
    # torch.nn is spelled out because no `nn` import is visible in the notebook
    criterion = torch.nn.CrossEntropyLoss(
        weight=torch.tensor(list(class_weights.values()), dtype=torch.float32).cuda()
    )
    optimizer = Adam(model.parameters(), lr=1e-4)
    best_loss = float("inf")
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.cuda(), labels.cuda()
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)

        # Validation
        model.eval()
        valid_loss = 0.0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.cuda(), labels.cuda()
                outputs = model(images)
                loss = criterion(outputs, labels)
                valid_loss += loss.item()
                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        epoch_valid_loss = valid_loss / len(valid_loader)

        # Compute and log the metrics before the early-stopping check, so the
        # epoch that triggers the stop is recorded as well.
        accuracy = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average="weighted")
        wandb.log(
            {
                "epoch": epoch,
                "accuracy": accuracy,
                "f1_score": f1,
                "train_loss": epoch_loss,
                "valid_loss": epoch_valid_loss,
            }
        )

        # Early stopping: checkpoint on improvement, otherwise count towards patience
        if epoch_valid_loss < best_loss:
            best_loss = epoch_valid_loss
            torch.save(model.state_dict(), checkpoint_path)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping")
                break

    wandb.finish()


# Train model
from model.H97_1__i_j import H97_1__i_j

# Pairwise model for classes 1 and 2, moved to the GPU before training
h97_1__1_2 = H97_1__i_j()
h97_1__1_2.cuda()
train(h97_1__1_2, train_loader_filtered, valid_loader_filtered)

## Train H97_1__0_2

In [None]:
import torch
import numpy as np

def filter_data(output_dict, labels_to_include):
    """Keep only the samples whose label is in ``labels_to_include``.

    Args:
        output_dict: dict with ``"x"`` (list of per-batch output tensors) and
            ``"y"`` (list of per-batch label tensors).
        labels_to_include: integer labels to keep.

    Returns:
        dict of the same form. Batches stay aligned; a batch with no matching
        sample becomes an empty tensor. Label values are not renumbered.
    """
    # String keys: the bare names x / y are not defined anywhere.
    filtered_outputs = {"x": [], "y": []}
    for outputs, labels in zip(output_dict["x"], output_dict["y"]):
        # Boolean mask over the batch, converted to a tensor for indexing
        mask = torch.from_numpy(np.isin(labels.cpu().numpy(), labels_to_include))
        filtered_outputs["x"].append(outputs[mask])
        filtered_outputs["y"].append(labels[mask])
    return filtered_outputs

# Filter to include only labels 0 and 2
# (CustomDataset assigns B2 -> 0, B5 -> 1, B6 -> 2, so this keeps B2 and B6)
labels_to_include = [0, 2]
filtered_trainset_output = filter_data(output_of_H97_1__0_1_2_on_trainset, labels_to_include)
filtered_validset_output = filter_data(output_of_H97_1__0_1_2_on_validset, labels_to_include)

In [None]:
from torch.utils.data import DataLoader, TensorDataset

def prepare_dataloader(filtered_output, batch_size=32, shuffle=True):
    """Wrap filtered outputs and labels in a ``DataLoader``.

    Args:
        filtered_output: dict with ``"x"`` and ``"y"`` lists of per-batch
            tensors; each list is concatenated along the first dimension.
        batch_size: batch size of the returned loader.
        shuffle: whether the loader reshuffles every epoch (default ``True``,
            the previous fixed behaviour).

    Returns:
        DataLoader over a ``TensorDataset`` of ``(outputs, labels)`` pairs.
    """
    # String keys: the bare names x / y are not defined anywhere.
    all_outputs = torch.cat(filtered_output["x"])
    all_labels = torch.cat(filtered_output["y"])
    dataset = TensorDataset(all_outputs, all_labels)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# Wrap the filtered outputs/labels of each split in a DataLoader
train_loader_filtered = prepare_dataloader(filtered_trainset_output)
valid_loader_filtered = prepare_dataloader(filtered_validset_output)

In [None]:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
import torch

# Function to extract labels from DataLoader
def extract_labels_from_dataloader(dataloader):
    """Gather every label yielded by ``dataloader`` into one NumPy array.

    The loader is iterated once; the first element of each batch is ignored.
    """
    collected = [
        label
        for _, batch_labels in dataloader
        for label in batch_labels.cpu().numpy()
    ]
    return np.array(collected)

# Extract labels from train_loader_filtered
train_labels = extract_labels_from_dataloader(train_loader_filtered)

# Calculate class weights ('balanced' mode, over the labels present in the loader)
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
# Keys are positions 0..k-1 in sorted-label order, not the label values.
# NOTE(review): the labels here are 0 and 2 but the weights are keyed 0 and 1,
# and the labels themselves are not renumbered - confirm this matches the
# number of outputs of H97_1__i_j and what the loss expects.
class_weights_dict = {i: weight for i, weight in enumerate(class_weights)}

print("Class Weights:", class_weights_dict)
# Rebind the global name that train() reads when it builds the loss
class_weights = class_weights_dict
# Convert class weights to a tensor
# class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).cuda()

In [None]:
from torch.optim import Adam
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.preprocessing import label_binarize

# Initialize WandB (the run is closed by the wandb.finish() call at the end of train())
wandb.init(
    project="AI_Product_ThyroidCancerClassifier",
    name="h97_1__0_2_dataver1",
)


def train(
    model,
    train_loader,
    valid_loader,
    num_epochs=100,
    patience=15,
    checkpoint_path="output_ex/best_model_h97_1__0_2_dataver1.pth",
):
    """Train ``model`` with class-weighted cross-entropy and early stopping.

    Reads the module-level ``class_weights`` dict for the loss weights (in
    the dict's insertion order), moves every batch to the GPU, logs per-epoch
    metrics to the active WandB run and closes that run before returning.

    Args:
        model: network to train; expected to already be on the GPU.
        train_loader: yields ``(inputs, labels)`` training batches.
        valid_loader: yields ``(inputs, labels)`` validation batches.
        num_epochs: maximum number of epochs.
        patience: number of consecutive epochs without a new best validation
            loss after which training stops.
        checkpoint_path: file that receives the ``state_dict`` with the
            lowest validation loss.
    """
    # torch.nn is spelled out because no `nn` import is visible in the notebook
    criterion = torch.nn.CrossEntropyLoss(
        weight=torch.tensor(list(class_weights.values()), dtype=torch.float32).cuda()
    )
    optimizer = Adam(model.parameters(), lr=1e-4)
    best_loss = float("inf")
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.cuda(), labels.cuda()
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)

        # Validation
        model.eval()
        valid_loss = 0.0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.cuda(), labels.cuda()
                outputs = model(images)
                loss = criterion(outputs, labels)
                valid_loss += loss.item()
                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        epoch_valid_loss = valid_loss / len(valid_loader)

        # Compute and log the metrics before the early-stopping check, so the
        # epoch that triggers the stop is recorded as well.
        accuracy = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average="weighted")
        wandb.log(
            {
                "epoch": epoch,
                "accuracy": accuracy,
                "f1_score": f1,
                "train_loss": epoch_loss,
                "valid_loss": epoch_valid_loss,
            }
        )

        # Early stopping: checkpoint on improvement, otherwise count towards patience
        if epoch_valid_loss < best_loss:
            best_loss = epoch_valid_loss
            torch.save(model.state_dict(), checkpoint_path)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping")
                break

    wandb.finish()


# Train model
from model.H97_1__i_j import H97_1__i_j

# Pairwise model for classes 0 and 2, moved to the GPU before training
h97_1__0_2 = H97_1__i_j()
h97_1__0_2.cuda()
train(h97_1__0_2, train_loader_filtered, valid_loader_filtered)

In [None]:
%matplotlib inline
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay, auc, roc_auc_score, roc_curve
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import numpy as np

def evaluate_model(
    model,
    valid_loader,
    checkpoint_path='output_ex/best_model_h97_1__0_2_dataver1.pth',
):
    """Evaluate a pairwise model on ``valid_loader`` and plot the results.

    Loads ``checkpoint_path`` into ``model``, runs it on the GPU, logs
    accuracy and weighted F1 to the active WandB run, prints them, and shows
    a confusion matrix and ROC curves.

    Args:
        model: network to evaluate; its weights are overwritten from
            ``checkpoint_path``.
        valid_loader: iterable of ``(inputs, labels)`` batches.
        checkpoint_path: ``state_dict`` file to load before evaluating.
    """
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()
    all_labels = []
    all_preds = []
    all_preds_prob = []

    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            preds_prob = torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            all_preds_prob.extend(preds_prob)

    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted')
    # Pin the label set: the filtered data holds only two of the three
    # classes, and the matrix must still match the three display labels.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1, 2])

    # Log metrics to WandB
    wandb.log({'accuracy': accuracy, 'f1_score': f1})

    print(f'Accuracy: {accuracy:.4f}')
    print(f'F1 Score: {f1:.4f}')

    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['B2', 'B5', 'B6'])
    disp.plot(cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.show()

    # ROC AUC
    all_labels_bin = label_binarize(all_labels, classes=[0, 1, 2])
    all_preds_prob = np.array(all_preds_prob)

    plt.figure(figsize=(10, 8))
    # Draw a curve only for classes that have a probability column and both
    # positive and negative samples; otherwise the curve cannot be computed.
    # NOTE(review): column i is taken to be class i - confirm against the
    # output layout of H97_1__i_j.
    n_prob_columns = all_preds_prob.shape[1] if all_preds_prob.ndim == 2 else 0
    for i in range(min(all_labels_bin.shape[1], n_prob_columns)):
        n_positive = int(all_labels_bin[:, i].sum())
        if n_positive == 0 or n_positive == len(all_labels_bin):
            continue
        fpr, tpr, _ = roc_curve(all_labels_bin[:, i], all_preds_prob[:, i])
        plt.plot(fpr, tpr, label=f'Class {i} (area = {auc(fpr, tpr):.2f})')

    plt.title('ROC Curve')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend(loc='lower right')
    plt.show()

# Evaluate the 0-vs-2 model on its filtered validation loader in a separate WandB run
wandb.init(project="AI_Product_ThyroidCancerClassifier", name='h97_1__0_2_dataver1_test')
evaluate_model(h97_1__0_2, valid_loader_filtered)
wandb.finish()