In [None]:
import torch
from datasets import load_dataset
from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification

dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

preprocessor = EfficientNetImageProcessor.from_pretrained("google/efficientnet-b0")
model = EfficientNetForImageClassification.from_pretrained("google/efficientnet-b0")

inputs = preprocessor(image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# model predicts one of the 1000 ImageNet classes
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label]),

In [None]:
import os
from pathlib import Path
from PIL import Image
from tqdm import tqdm

def find_corrupted_images(data_dir, extensions=('.jpg', '.jpeg', '.png')):
    corrupted = []
    for img_path in Path(data_dir).rglob('*'):
        if img_path.suffix.lower() in extensions:
            try:
                img = Image.open(img_path)
                img.load()  # Быстрая проверка целостности
            except (IOError, OSError, Image.UnidentifiedImageError) as e:
                print(f"Битый файл: {img_path} — {e}")
                corrupted.append(str(img_path))
                img = Image.new('RGB', (224, 224), color='black')
    return corrupted

path_data = 'dataset/food11'


for t in ['test', 'train']:
    pt = os.path.join(path_data, t)
    for cls in pt:
        find_corrupted_images(cls)


In [8]:
emun_dir = [('apple_pie', 0), ('cheesecake', 1), ('chicken_curry', 2),
            ('french_fries', 3), ('fried_rice', 4), ('hamburger', 5),
            ('hot_dog', 6), ('ice_cream', 7), ('omelette', 8), 
            ('pizza', 9), ('sushi', 10)]

def scan_and_clean_data(root_dir):
    image_paths = []
    labels = []
    
    for label_name, label_id in emun_dir:
        path = os.path.join(root_dir, label_name)
                
        for filename in tqdm(os.listdir(path), desc=f"{label_name}"):
            file_path = os.path.join(path, filename)
            
            if not filename.lower().endswith(('.jpg', '.jpeg')):
                continue
            
            try:
                img = Image.open(file_path)
                img.load()
                
                if img.mode not in ('RGB', 'L'):
                    continue

                image_paths.append(file_path)
                labels.append(label_id)
            except Exception:
                continue

    print(f"total {len(labels)}")
    print(f"cat {labels.count(0)}, dog {labels.count(1)}")
    return image_paths, labels

path_data = 'dataset/food11'


for t in ['test', 'train']:
    scan_and_clean_data(os.path.join(path_data, t))

apple_pie: 100%|██████████| 100/100 [00:00<00:00, 138.50it/s]
cheesecake: 100%|██████████| 100/100 [00:00<00:00, 116.95it/s]
chicken_curry: 100%|██████████| 100/100 [00:00<00:00, 122.43it/s]
french_fries: 100%|██████████| 100/100 [00:00<00:00, 120.47it/s]
fried_rice: 100%|██████████| 100/100 [00:00<00:00, 112.58it/s]
hamburger: 100%|██████████| 100/100 [00:00<00:00, 103.15it/s]
hot_dog: 100%|██████████| 100/100 [00:00<00:00, 115.50it/s]
ice_cream: 100%|██████████| 100/100 [00:00<00:00, 111.09it/s]
omelette: 100%|██████████| 100/100 [00:00<00:00, 125.92it/s]
pizza: 100%|██████████| 100/100 [00:00<00:00, 125.56it/s]
sushi: 100%|██████████| 100/100 [00:00<00:00, 116.91it/s]


total 1100
cat 100, dog 100


apple_pie: 100%|██████████| 900/900 [00:07<00:00, 113.30it/s]
cheesecake: 100%|██████████| 900/900 [00:09<00:00, 98.24it/s] 
chicken_curry: 100%|██████████| 900/900 [00:08<00:00, 111.77it/s]
french_fries: 100%|██████████| 900/900 [00:07<00:00, 119.11it/s]
fried_rice: 100%|██████████| 900/900 [00:07<00:00, 115.37it/s]
hamburger: 100%|██████████| 900/900 [00:07<00:00, 113.04it/s]
hot_dog: 100%|██████████| 900/900 [00:07<00:00, 116.06it/s]
ice_cream: 100%|██████████| 900/900 [00:07<00:00, 114.61it/s]
omelette: 100%|██████████| 900/900 [00:07<00:00, 116.04it/s]
pizza: 100%|██████████| 900/900 [00:07<00:00, 117.97it/s]
sushi: 100%|██████████| 900/900 [00:08<00:00, 111.69it/s]

total 9900
cat 900, dog 900





In [1]:
from torchvision.models.efficientnet import EfficientNet

In [2]:
import torchvision.transforms as transforms
from torch.utils.data import Dataset 
from torch.utils import data

import torchvision.datasets as datasets

In [3]:
transform = transforms.Compose([
    transforms.Resize((300, 300)), 
    transforms.ToTensor(),                       
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15)
])

In [4]:
transform_val = transforms.Compose([
    transforms.Resize((450, 450)), 
    transforms.ToTensor(),                       
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [13]:
class minclass4torch(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.placeholder_image = Image.new('RGB', (224, 224), color = 'black')

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = self.placeholder_image
        
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            pass

        if self.transform:
            image = self.transform(image)
        
        return image, label

In [None]:
for t in ['test', 'train']:
    all_image_paths, all_labels = scan_and_clean_data(os.path.join(path_data, t))


In [16]:
full_dataset = minclass4torch(all_image_paths, all_labels, transform=transforms)

In [5]:
train_dataset = datasets.ImageFolder(root='dataset/food11/train', transform=transform)
test_dataset = datasets.ImageFolder(root='dataset/food11/test', transform=transform_val)

train_dataloader = data.DataLoader(train_dataset, batch_size=32, num_workers=4)
test_dataloader = data.DataLoader(test_dataset, batch_size=32, num_workers=4)

In [6]:
import torch.nn as nn 
import torchvision.models as models
import torch
from tqdm import tqdm 

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:
model_eff = models.efficientnet_b7( weights='DEFAULT')

In [9]:
for param in model_eff.parameters():
    param.requires_grad = False

l = model_eff.classifier[-1].in_features

model_eff.classifier[-1] = nn.Linear(in_features=l, out_features=11)

In [None]:
model_eff.to(device)

In [11]:
import torch.optim as optim
from sklearn.metrics import f1_score, confusion_matrix

In [12]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_eff.parameters(), lr=1e-3)

In [13]:
num_epochs = 3

all_targets = []
all_preds = []

def check_f1_score(model, loader, device):
    model.eval()
    all_targets = []
    all_preds = []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            
            scores = model(x)
            
            _, predictions = scores.max(1)
            
            all_targets.extend(y.cpu().numpy())
            all_preds.extend(predictions.cpu().numpy())
    f1 = f1_score(all_targets, all_preds, average='macro') 
    model.train()
    return f1


for epoch in range(num_epochs):
    
    model_eff.train()
    total_loss = 0.0
    
    train_loop = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} (Train)")
    
    for batch_idx, (data, targets) in enumerate(train_loop):
        data = data.to(device)
        targets = targets.to(device)
        scores = model_eff(data)
        loss = criterion(scores, targets)
        
        optimizer.zero_grad()
        loss.backward()
        
        optimizer.step()
        
        total_loss += loss.item()
    
        
        train_loop.set_postfix(Loss=loss.item())
    avg_train_loss = total_loss / len(train_dataloader)
    val_f1 = check_f1_score(model_eff, test_dataloader, device)
    print(f"Epoch {epoch+1}")
    print(f"  -> avg loss train: {avg_train_loss:.4f}")
    print(f"  -> F1 test: {val_f1:.4f}")
    
final_f1 = check_f1_score(model_eff, test_dataloader, device)
print(f"final F1 {final_f1:.4f}")

Epoch 1/3 (Train): 100%|██████████| 310/310 [03:06<00:00,  1.66it/s, Loss=1.87]


Epoch 1
  -> avg loss train: 2.8455
  -> F1 test: 0.0512


Epoch 2/3 (Train): 100%|██████████| 310/310 [03:07<00:00,  1.65it/s, Loss=1.89]


Epoch 2
  -> avg loss train: 2.8836
  -> F1 test: 0.0822


Epoch 3/3 (Train): 100%|██████████| 310/310 [03:02<00:00,  1.69it/s, Loss=1.1] 


Epoch 3
  -> avg loss train: 2.8215
  -> F1 test: 0.1184
final F1 0.1184
