In [1]:
import os
import sys
root = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(root)

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib.patches as patches
import pandas as pd
import numpy as np
import plotly.graph_objs as go
import plotly.express as px
from torch.utils.data import Dataset
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torchvision
from torchvision.models import efficientnet_b0
from torchvision.transforms import ToTensor, Compose, ConvertImageDtype
import torch.nn.functional as F

from Data_manager import EfficientDataset, eff_collate_fn

In [2]:
class EfficientNet(nn.Module):
    """Segmentation network: EfficientNet-B0 feature encoder + transposed-conv decoder.

    The encoder is the ``features`` stack of torchvision's ``efficientnet_b0``
    loaded with ``weights='DEFAULT'``. Its 1280-channel output is upsampled by
    five stride-2 transposed convolutions (32x in total, each followed by ReLU)
    and projected to ``num_classes`` channels by a 1x1 convolution.

    Args:
        num_classes: number of output channels (one logit map per class).
    """

    def __init__(self, num_classes):
        super().__init__()
        self.encoder = efficientnet_b0(weights='DEFAULT').features

        # Channel widths along the upsampling path: 1280 -> 512 -> ... -> 32.
        # Layers are appended in exactly the order of an explicit nn.Sequential
        # listing, so state_dict keys (decoder.0, decoder.2, ..., decoder.10)
        # stay the same and existing checkpoints still load.
        widths = (1280, 512, 256, 128, 64, 32)
        stages = []
        for in_ch, out_ch in zip(widths, widths[1:]):
            stages.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2))
            stages.append(nn.ReLU())
        stages.append(nn.Conv2d(widths[-1], num_classes, kernel_size=1))
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Return per-pixel class logits at the spatial size of ``x``.

        Args:
            x: 4-D image batch; presumably (N, 3, H, W) float — TODO confirm
               against EfficientDataset.

        Returns:
            Logits of shape (N, num_classes, H, W).
        """
        upsampled = self.decoder(self.encoder(x))
        # The decoder's fixed 32x upsampling need not land exactly on the input's
        # H x W for arbitrary input sizes, so resize to match it explicitly.
        return F.interpolate(upsampled, size=x.shape[2:], mode='bilinear', align_corners=False)

In [3]:
# Batch size shared by both loaders.
BATCH_SIZE = 2
# Seed for the training-loader shuffle, so the batch order is reproducible
# across kernel restarts.
LOADER_SEED = 42

train_root_dir = '../data/Hesperidine_u87m6_5_10_g_neg'
train_dataset = EfficientDataset(train_root_dir)
# Shuffle the training set so each epoch visits batches in a different order
# (a fixed order biases SGD). The seeded generator keeps the order reproducible.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(LOADER_SEED),
    collate_fn=eff_collate_fn,
)

valid_root_dir = '../data/Control_u87m6_5_g_neg'
valid_dataset = EfficientDataset(valid_root_dir)
# Validation only measures loss, so a fixed (unshuffled) order is fine.
valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=eff_collate_fn,
)

In [None]:
# Fall back to CPU when no CUDA device is present, so the notebook still runs
# (slowly) on GPU-less machines instead of failing at model.to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_epochs = 15
num_classes = 3
model = EfficientNet(num_classes=num_classes)

model.to(device)
# Only parameters with requires_grad=True are optimized.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
# NOTE(review): StepLR divides the LR by 10 every 2 scheduler steps. If stepped
# once per epoch, the LR in the last of 15 epochs is 1e-3 * 0.1**7 = 1e-10, so
# the later epochs barely update the weights. Consider a larger step_size or a
# different schedule. Left unchanged here to preserve reported results.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
# Presumably targets are integer class-index masks — TODO confirm against eff_collate_fn.
criterion = nn.CrossEntropyLoss()

In [None]:
# Per-epoch average losses; plotted as loss curves afterwards.
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    # --- Training pass ---
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # Each batch is (images, targets); targets is a sequence of per-sample
        # tensors that are moved to the device and then stacked into one batch
        # tensor. Presumably (H, W) class-index masks — TODO confirm.
        images, targets = batch
        images = images.to(device)
        targets = [target.to(device) for target in targets]
        # Training step
        preds = model(images)
        targets_tensor = torch.stack(targets)
        loss = criterion(preds, targets_tensor)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    # Mean of per-batch losses over the number of batches; a smaller final
    # batch is weighted the same as a full one.
    avg_train_loss = total_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}")
    
    # Validation: eval mode and no gradient tracking; weights are not updated.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(valid_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}"):
            images, targets = batch
            images = images.to(device)
            targets = [target.to(device) for target in targets]
            preds = model(images)
            targets_tensor = torch.stack(targets)
            loss = criterion(preds, targets_tensor)
            val_loss += loss.item()
    
    avg_val_loss = val_loss / len(valid_loader)
    val_losses.append(avg_val_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], Val Loss: {avg_val_loss:.4f}")
    
    # The LR schedule advances once per epoch, after validation.
    scheduler.step()

In [None]:
# Plot the recorded per-epoch losses. The number of points comes from the
# history itself rather than num_epochs. If training was interrupted early
# (possibly mid-epoch, after train_losses but before val_losses was appended),
# the lists are shorter than num_epochs or differ by one, and building the frame
# from num_epochs would raise a length-mismatch error.
epochs_recorded = min(len(train_losses), len(val_losses))
df = pd.DataFrame({
    'Epoch': range(1, epochs_recorded + 1),
    'Train Loss': train_losses[:epochs_recorded],
    'Val Loss': val_losses[:epochs_recorded],
})
fig = px.line(
    df,
    x='Epoch',
    y=['Train Loss', 'Val Loss'],
    title='Training and Validation Loss Curves',
    # Wide-form px.line labels the y-axis "value" and the legend "variable" by default.
    labels={'value': 'Loss', 'variable': 'Split'},
)
fig.show()

In [None]:
CHECKPOINT_PATH = '../trained/test_efficient.pth'
# torch.save does not create missing parent directories; make sure it exists.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
torch.save(model.state_dict(), CHECKPOINT_PATH)
# Message reads: "Model saved as test_efficient.pth".
print("Модель сохранена как test_efficient.pth")

# Test: visualize model predictions

In [None]:
def visualize_prediction(image, prediction):
    """Display an image with its predicted class mask overlaid.

    Args:
        image: a single image tensor in channel-first (C, H, W) layout; it is
            permuted to (H, W, C) for display. Its value range is whatever
            matplotlib's imshow accepts — TODO confirm the dataset's scaling.
        prediction: per-class scores for the same image, shape
            (num_classes, H, W); the predicted class is the argmax over dim 0.

    Pixels predicted as class 0 (presumably background) are left transparent.
    All other classes are drawn opaquely with the 'jet' colormap on a fixed
    0..num_classes-1 scale, so a class keeps the same colour across images.
    """
    num_classes = prediction.shape[0]
    # detach() so a tensor that still tracks gradients can be converted to numpy.
    image_np = image.detach().permute(1, 2, 0).cpu().numpy()
    pred_mask = torch.argmax(prediction, dim=0).cpu().numpy()

    fig, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(image_np)

    masked = np.ma.masked_where(pred_mask == 0, pred_mask)
    # Pin the colour scale: without vmin/vmax, matplotlib rescales to the
    # classes present in this particular mask (e.g. a mask with only class 1
    # collapses vmin == vmax), so colours would not be comparable across
    # images. 'nearest' avoids blending neighbouring class colours.
    ax.imshow(
        masked,
        alpha=1.0,
        cmap='jet',
        vmin=0,
        vmax=num_classes - 1,
        interpolation='nearest',
    )

    ax.axis('off')
    plt.show()