In [31]:
import sys

# Add the parent directory to the import search path (presumably where the
# local `actlearn` package lives — confirm). The membership check keeps the
# cell idempotent: re-running it no longer appends duplicate entries.
if '../' not in sys.path:
    sys.path.append('../')

In [32]:
import actlearn as al 
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch.optim as optim
import torch
from sklearn.metrics import f1_score, accuracy_score
import torch.nn.functional as F
import os
import glob
import torchvision.transforms as transforms
import random

In [None]:
class CNN(nn.Module):
    """Minimal single-convolution classifier for square, single-channel images.

    Pipeline: 3x3 convolution (1 -> 8 channels, size-preserving padding),
    ReLU, max-pooling, flatten, and one linear layer producing class logits.

    Args:
        num_classes: Number of output logits.
        input_size: Side length, in pixels, of the square input images. The
            default of 256 reproduces the previously hard-coded
            ``8 * 16 * 16`` classifier input.
        pool_size: Kernel size of the max-pooling layer (its stride defaults
            to the same value).

    Raises:
        ValueError: If ``pool_size`` is not positive or ``input_size`` is
            smaller than ``pool_size`` (pooling would produce no output).
    """

    def __init__(self, num_classes=4, input_size=256, pool_size=16):
        super().__init__()

        if pool_size < 1 or input_size < pool_size:
            raise ValueError(
                f"input_size ({input_size}) must be >= pool_size ({pool_size}) >= 1"
            )

        conv_channels = 8
        self.conv = nn.Conv2d(1, conv_channels, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(pool_size)

        self.flat = nn.Flatten()
        # The padded 3x3 convolution keeps the spatial size, and MaxPool2d(k)
        # with its default stride shrinks each side to floor(size / k), so the
        # flattened feature count follows directly from the two parameters.
        pooled_side = input_size // pool_size
        self.fc = nn.Linear(conv_channels * pooled_side * pooled_side, num_classes)

    def forward(self, x):
        """Return raw class logits for a batch of images.

        Args:
            x: Batch of images; the layers above expect one channel and a
                spatial size of ``input_size`` x ``input_size``.

        Returns:
            Tensor of logits with one row per image and ``num_classes`` columns.
        """
        x = self.pool(F.relu(self.conv(x)))
        x = self.flat(x)
        x = self.fc(x)
        return x

    def predict_proba(self, X):
        """Return class probabilities (softmax over the logits of ``forward``).

        NOTE(review): gradient tracking is not disabled here; if callers only
        use this for scoring, wrapping the call in ``torch.no_grad()`` would
        save memory — confirm how ``actlearn`` uses it before changing.
        """
        logits = self(X)
        return F.softmax(logits, dim=1)


# Rebinds the name `CNN` from the class to a model instance, which is what the
# later cells pass to the active-learning wrapper. The class itself remains
# reachable as `type(CNN)` if another model needs to be built.
CNN = CNN()

In [None]:
# Preprocessing applied to every loaded image, in order: collapse to a single
# grayscale channel, resize to 256x256, and convert to a tensor.
_preprocessing_steps = [
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

def load_images_from_folder_shuffled(base_path, seed=42):
    """Load every ``*.jpg`` image under ``base_path`` in a seeded random order.

    Each immediate sub-directory of ``base_path`` is treated as one class, and
    class indices are assigned following the sorted directory names. Every
    image is passed through the module-level ``transform`` pipeline.

    Args:
        base_path: Directory containing one sub-directory per class.
        seed: Seed used for the shuffle. As a side effect this reseeds
            Python's global ``random`` generator.

    Returns:
        A tuple ``(images, labels, classes)``: the transformed images, their
        integer class indices in matching order, and the sorted class names.
        Files that fail to load are reported and skipped, so the two lists may
        be shorter than the number of files found.
    """
    classes = sorted(
        entry for entry in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, entry))
    )
    class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}

    print(f"Найдены классы: {classes}")
    print(f"Сопоставление: {class_to_idx}")

    all_data = []
    for class_name in classes:
        class_dir = os.path.join(base_path, class_name)
        # glob returns paths in arbitrary, filesystem-dependent order; sorting
        # first makes the seeded shuffle below reproducible across machines.
        image_paths = sorted(glob.glob(os.path.join(class_dir, "*.jpg")))

        print(f"Класс '{class_name}': {len(image_paths)} изображений")

        label = class_to_idx[class_name]
        all_data.extend((img_path, label) for img_path in image_paths)

    random.seed(seed)
    random.shuffle(all_data)

    images = []
    labels = []
    for img_path, label in all_data:
        try:
            # Image.open is lazy and keeps the file handle open; the context
            # manager releases it once the image has been transformed.
            with Image.open(img_path) as image:
                tensor = transform(image)
        except Exception as e:
            # Best effort: report the unreadable file and continue.
            print(f"Ошибка при загрузке {img_path}: {e}")
            continue
        images.append(tensor)
        labels.append(label)

    return images, labels, classes

      

In [None]:
# Load the training split and pack it into tensors: images are stacked along
# a new leading dimension, labels become a 1-D tensor.
train_images, train_labels, classes = load_images_from_folder_shuffled(
    "data_AL/BrainTumourMRI/Training"
)
train_data, train_labels = torch.stack(train_images), torch.tensor(train_labels)
print(train_data.shape, train_labels.shape)

In [None]:
# Load the testing split and pack it into tensors.
test_images, test_labels, test_classes = load_images_from_folder_shuffled(
    "data_AL/BrainTumourMRI/Testing"
)
# Label indices are derived independently from each folder's sub-directories,
# so the two splits only share a label encoding when their class lists match.
# Fail loudly instead of silently overwriting `classes` with a different list.
if test_classes != classes:
    raise ValueError(
        f"Testing classes {test_classes} do not match training classes {classes}"
    )
test_data = torch.stack(test_images)
test_labels = torch.tensor(test_labels)
print(test_data.shape, test_labels.shape)

In [36]:
# Fraction of each split to keep; 1 keeps everything.
alpha = 1


def reduce_dataset(X, y, alpha):
    """Keep only the leading ``alpha`` fraction of a dataset.

    Args:
        X: Sliceable collection of samples (anything supporting ``len`` and
            ``[:n]``).
        y: Labels aligned with ``X``; must have the same length.
        alpha: Non-negative fraction of samples to keep. The kept count is
            ``int(alpha * len(X))`` (truncated); values of 1 or more keep the
            whole dataset.

    Returns:
        Tuple ``(X[:n_keep], y[:n_keep])``.

    Raises:
        ValueError: If ``X`` and ``y`` differ in length, or ``alpha`` is
            negative (a negative count would otherwise be read as a slice
            offset from the end and silently keep most of the data).
    """
    n_samples = len(X)
    if len(y) != n_samples:
        raise ValueError(
            f"X and y must have the same length, got {n_samples} and {len(y)}"
        )
    if alpha < 0:
        raise ValueError(f"alpha must be non-negative, got {alpha}")

    n_keep = int(alpha * n_samples)

    return X[:n_keep], y[:n_keep]

In [None]:
# Trim both splits to the leading `alpha` fraction and report the new shapes.
# Caveat: the inputs are overwritten in place, so re-running this cell with
# alpha < 1 shrinks the data again.
train_data, train_labels = reduce_dataset(X=train_data, y=train_labels, alpha=alpha)
test_data, test_labels = reduce_dataset(X=test_data, y=test_labels, alpha=alpha)
for features, targets in ((train_data, train_labels), (test_data, test_labels)):
    print(features.shape, targets.shape)

In [57]:
# Build the active-learning driver around the CNN instance and both splits.
# The keyword options are collected first so they can be read (and tuned) in
# one place; they are forwarded unchanged.
al_options = dict(
    strategy="margin",
    al_type="incremental",
    init_size=128,
    update_size=128,
    batch_size=128,
    metric="f1",
    logs=True,
    skip=True,
)
AL = al.ActiveLearning(CNN, train_data, train_labels, test_data, test_labels, **al_options)

In [None]:
# Run the active-learning procedure and plot its results.
# NOTE(review): `stop_ratio=1` presumably lets the loop continue until the
# whole training pool has been used — confirm against the actlearn API.
AL.fit(stop_ratio=1)
al.plot_active_learning_results_many(AL, dataset_name="Brain-Tumour-MRI")