In [4]:

import csv
import glob
import os

import torch
import torch.nn as nn
from PIL import Image
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import efficientnet_b0



In [None]:

# Training data layout: class 0 = "correct" photos, classes 1-4 = the four
# "fictitious" sub-folders, labelled in the order listed below.
base_path = "/content/drive/My Drive/datasaur/case3-datasaur-photo/techosmotr/techosmotr/train"
real_path = os.path.join(base_path, "pravilniye(correct)/0-correct/*.*")
fake_paths = [
    os.path.join(base_path, f"fictivniye(fictitious)/{subclass}/*.*")
    for subclass in ["1-not-on-the-brake-stand", "2-from-the-screen", "3-from-the-screen+photoshop", "4-photoshop"]
]

# glob returns files in filesystem-dependent order; sort so that the seeded
# train/val split below is reproducible across machines and runs.
real_images = sorted(glob.glob(real_path))
# Scan each fake sub-folder exactly once and derive both the image list and
# the labels from that same scan, so they can never disagree.
fake_images_by_class = [sorted(glob.glob(path)) for path in fake_paths]
fake_images = [img for class_images in fake_images_by_class for img in class_images]

all_images = real_images + fake_images
all_labels = [0] * len(real_images) + [
    label
    for label, class_images in enumerate(fake_images_by_class, start=1)
    for _ in class_images
]
assert len(all_images) == len(all_labels), "image/label count mismatch"

X_train, X_val, y_train, y_val = train_test_split(all_images, all_labels, test_size=0.2, random_state=42)

In [None]:
# EfficientNet-B0 with pretrained weights; the last classifier layer is
# replaced by a fresh 5-way linear head (label 0 plus fictitious labels 1-4).
model = efficientnet_b0(pretrained=True)
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features=num_ftrs, out_features=5)

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# Preprocessing: resize (shorter side -> 256), centre-crop to 224x224, convert
# to a tensor, then normalise each of the 3 channels with the mean/std below.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
data_transforms = transforms.Compose(preprocessing_steps)
class CustomDataset(Dataset):
    """Map-style dataset that loads images from disk and pairs them with labels.

    Args:
        image_paths: sequence of image file paths.
        labels: sequence of class labels aligned index-for-index with
            ``image_paths``.
        transform: optional callable applied to each loaded PIL image
            (e.g. a ``torchvision.transforms.Compose`` pipeline).

    Raises:
        ValueError: if ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"({len(image_paths)} != {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to 3-channel RGB before ``transform`` is applied.
        """
        # Grayscale, palette or RGBA files would otherwise produce tensors whose
        # channel count does not match a 3-channel mean/std normalisation.
        # The ``with`` block closes the file handle Image.open keeps open for
        # lazy loading; convert() returns an independent, fully loaded image.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]

# Both splits share the same deterministic preprocessing pipeline.
train_dataset = CustomDataset(image_paths=X_train, labels=y_train, transform=data_transforms)
val_dataset = CustomDataset(image_paths=X_val, labels=y_val, transform=data_transforms)

BATCH_SIZE = 32
# Reshuffle training batches every epoch; evaluation needs no shuffling.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Fine-tune the whole network with cross-entropy loss and Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
for epoch in range(num_epochs):
    # Put the network in training mode (affects layers such as dropout/batch-norm).
    model.train()
    running_loss = 0.0
    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()

    # Reported value is the mean of the per-batch losses for this epoch.
    mean_batch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {mean_batch_loss}")

In [None]:
# Score the held-out validation split and print per-class precision/recall/F1.
model.eval()
all_preds, all_true = [], []
with torch.no_grad():
    for batch_images, batch_labels in val_loader:
        logits = model(batch_images.to(device))
        batch_preds = logits.argmax(dim=1)  # index of the highest logit per row
        all_preds += batch_preds.cpu().tolist()
        all_true += batch_labels.tolist()

print(classification_report(all_true, all_preds))


In [None]:
# Collect the unlabelled test images (entries in test/ matching "*.*").
test_dir = "/content/drive/My Drive/datasaur/case3-datasaur-photo/techosmotr/techosmotr/test"
test_path = test_dir + "/*.*"
test_images = list(glob.iglob(test_path))


In [None]:
# The test split is unlabelled, so every image gets a placeholder label of 0;
# only the images themselves are consumed at inference time.
placeholder_labels = [0 for _ in test_images]
test_dataset = CustomDataset(test_images, placeholder_labels, transform=data_transforms)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Predict a class for every test image. The loader is not shuffled, so
# all_preds stays in the same order as test_images.
model.eval()
all_preds = []
with torch.no_grad():
    for batch_images, _placeholder_labels in test_loader:
        logits = model(batch_images.to(device))
        all_preds.extend(logits.argmax(dim=1).cpu().tolist())


In [None]:
# Build the submission: one row per test image with its file stem and the
# predicted class index. os.path.splitext strips only the final extension, so a
# stem that itself contains dots is kept whole (split('.')[0] would cut it at
# the first dot).
file_indices = [os.path.splitext(os.path.basename(img_path))[0] for img_path in test_images]

# `all_preds` is also the name used for the validation predictions earlier in
# the notebook; if the test-inference cell was skipped, zip() would silently
# pair file names with the wrong predictions and truncate to the shorter list.
assert len(file_indices) == len(all_preds), (
    f"{len(file_indices)} test images but {len(all_preds)} predictions"
)
submission_data = list(zip(file_indices, all_preds))

with open('submission.csv', 'w', newline='', encoding='utf-8') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(["file_index", "class"])
    csvwriter.writerows(submission_data)