Обучение EfficientNet

In [None]:
import os
import json
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from efficientnet_pytorch import EfficientNet
from torch import nn, optim
from tqdm import tqdm

In [None]:
class MedicalPPEDataset(Dataset):
    """Image-classification dataset driven by a COCO-style annotation file.

    Each sample is ``(image, label)`` where ``label`` is the zero-based
    position of the image's category inside the ``categories`` list of the
    annotation file. Raw COCO ``category_id`` values are NOT used as labels
    because they are not guaranteed to start at 0 or to be contiguous, while
    a classifier with ``num_classes`` outputs needs targets in
    ``[0, num_classes)``.

    Args:
        coco_annotation_file: Path to a JSON file with the top-level keys
            ``images``, ``annotations`` and ``categories``.
        image_dir: Directory that the ``file_name`` of every image entry is
            joined onto.
        transform: Optional callable applied to the loaded RGB PIL image.

    Raises:
        ValueError: If an annotation references a ``category_id`` that is not
            listed under ``categories``.
    """

    def __init__(self, coco_annotation_file, image_dir, transform=None):
        with open(coco_annotation_file, 'r', encoding='utf-8') as f:
            self.coco_data = json.load(f)
        self.image_dir = image_dir
        self.transform = transform

        # image_id -> file_name (kept as a public lookup table).
        self.image_id_to_file = {
            img['id']: img['file_name'] for img in self.coco_data['images']
        }

        # image_id -> raw COCO category_id. One class per image is assumed;
        # if an image has several annotations the last one listed wins.
        self.image_id_to_category = {
            ann['image_id']: ann['category_id']
            for ann in self.coco_data['annotations']
        }

        self.categories = self.coco_data['categories']
        self.num_classes = len(self.categories)

        # raw COCO category_id -> contiguous label in [0, num_classes), in the
        # order of ``self.categories`` so that ``self.categories[label]`` is
        # the category a predicted label refers to.
        self.category_id_to_label = {
            cat['id']: index for index, cat in enumerate(self.categories)
        }

        # Fail at construction time instead of in the middle of an epoch.
        unknown_ids = (
            set(self.image_id_to_category.values())
            - set(self.category_id_to_label)
        )
        if unknown_ids:
            raise ValueError(
                'Annotations reference category ids missing from '
                f'"categories": {sorted(unknown_ids, key=repr)}'
            )

        # Only images that actually carry an annotation can be labelled, so
        # un-annotated images are left out of the index instead of raising
        # KeyError when they are requested.
        self.images = [
            img for img in self.coco_data['images']
            if img['id'] in self.image_id_to_category
        ]

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the ``idx``-th annotated image and return ``(image, label)``."""
        img_info = self.images[idx]
        img_path = os.path.join(self.image_dir, img_info['file_name'])
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        category_id = self.image_id_to_category[img_info['id']]
        label = self.category_id_to_label[category_id]
        return image, label

In [None]:
# Preprocessing pipeline applied to every image: resize to 224x224, convert
# to a tensor, then normalize each of the three channels.
# NOTE(review): the mean/std values look like the usual ImageNet statistics —
# confirm they suit the pretrained weights in use.
data_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

In [None]:
# Build the training and validation datasets; both use the same preprocessing.
train_dataset = MedicalPPEDataset(
    coco_annotation_file='path/to/train_annotations.json',
    image_dir='path/to/train_images',
    transform=data_transforms,
)
val_dataset = MedicalPPEDataset(
    coco_annotation_file='path/to/val_annotations.json',
    image_dir='path/to/val_images',
    transform=data_transforms,
)

In [None]:
# Batch the datasets: training batches are shuffled, validation batches keep
# the dataset order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Load a pretrained EfficientNet-B0 with one output per dataset category.
model = EfficientNet.from_pretrained(
    model_name='efficientnet-b0',
    num_classes=train_dataset.num_classes,
)

In [None]:
# Cross-entropy classification loss, optimized with Adam over all model
# parameters at a learning rate of 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Training function
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch; otherwise the model is put in evaluation mode and
    gradients are disabled.

    Both metrics are plain Python floats averaged over the number of samples
    actually seen, so they stay correct if the loader drops or sub-samples
    data. If the loader yields no samples, ``(nan, nan)`` is returned.
    """
    training = optimizer is not None
    model.train(training)  # model.train(False) is equivalent to model.eval()

    total_loss = 0.0
    total_correct = 0
    total_seen = 0

    with torch.set_grad_enabled(training):
        for inputs, labels in tqdm(loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            batch_size = inputs.size(0)
            # The loss is weighted by batch size so that a smaller final
            # batch does not skew the per-sample mean.
            total_loss += loss.item() * batch_size
            total_correct += (preds == labels).sum().item()
            total_seen += batch_size

    if total_seen == 0:
        return float('nan'), float('nan')
    return total_loss / total_seen, total_correct / total_seen


def train_model(model, criterion, optimizer, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Batches are read from the module-level ``train_loader`` and
    ``val_loader``. The model is moved to the first CUDA device when one is
    available, otherwise it stays on the CPU. Per-epoch loss and accuracy for
    both phases are printed.

    Args:
        model: Network to optimize; modified in place.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of train + validation passes to run.

    Returns:
        The same ``model`` object after training.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        epoch_loss, epoch_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer
        )
        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Validation phase
        epoch_loss, epoch_acc = _run_epoch(model, val_loader, criterion, device)
        print(f'Val Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

    return model

In [None]:
# Run the full training loop for 10 epochs.
trained_model = train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
)

# Persist only the learned weights (the state dict), not the whole model object.
torch.save(obj=trained_model.state_dict(), f='medical_ppe_efficientnet.pth')

# Single-image inference helper
def test_single_image(model, image_path):
    """Classify one image file and return the predicted category name.

    The image is loaded as RGB, run through the module-level
    ``data_transforms`` and fed to ``model`` as a batch of one. The index of
    the highest-scoring output selects an entry of ``train_dataset.categories``
    whose ``'name'`` field is returned.

    As a side effect ``model`` is moved to the first CUDA device (when
    available, otherwise the CPU) and switched to evaluation mode.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    model.to(device)
    model.eval()

    pil_image = Image.open(image_path).convert('RGB')
    batch = data_transforms(pil_image).unsqueeze(0)  # add the batch dimension
    batch = batch.to(device)

    with torch.no_grad():
        scores = model(batch)
        predicted = torch.max(scores, 1).indices

    category = train_dataset.categories[predicted.item()]
    return category['name']

# Demo: classify one image with the freshly trained model and report the result.
test_image_path = 'path/to/test_image.jpg'
predicted_class = test_single_image(
    model=trained_model,
    image_path=test_image_path,
)
print(f'Predicted class: {predicted_class}')