In [1]:
from google.colab import drive

# Mount Google Drive so the dataset archive, the trained checkpoint and the
# submission output under MyDrive are reachable from this runtime.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
! unzip  /content/drive/MyDrive/released.zip -d ./data


Archive:  /content/drive/MyDrive/released.zip
replace ./data/test/00cb0c05-992f-4b41-83ed-842e4f5239ba.pkl? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [5]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset, random_split
import pandas as pd
import pickle
from PIL import Image
from sklearn.utils.class_weight import compute_class_weight

class BagDataset(Dataset):
    """Map-style dataset of multiple-instance "bags" (sequences of images).

    Args:
        bags: sequence of bags. Each bag is an iterable of images accepted by
            ``PIL.Image.fromarray`` (presumably HxWxC uint8 numpy arrays --
            TODO confirm against the pickled data).
        labels: optional sequence parallel to ``bags``. When given, an item is
            ``(bag_tensor, label)``; otherwise it is just ``bag_tensor``.
        transform: optional callable applied to each image after conversion to
            a PIL image. It must return a tensor, because the per-image results
            are combined with ``torch.stack``.

    Note: PyTorch's default collate can only batch bags that have the same
    number of images and the same per-image shape.
    """

    def __init__(self, bags, labels=None, transform=None):
        self.bags = bags
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.bags)

    def __getitem__(self, idx):
        """Return bag ``idx`` as one stacked tensor (plus its label if labelled)."""
        bag = self.bags[idx]
        if self.transform:
            images = [self.transform(Image.fromarray(image)) for image in bag]
        else:
            # Without a transform the images are raw arrays, and torch.stack only
            # accepts tensors, so convert each image first.
            images = [torch.as_tensor(image) for image in bag]
        bag_tensor = torch.stack(images)
        if self.labels is not None:
            return bag_tensor, self.labels[idx]
        return bag_tensor
from torchvision.models import ResNet50_Weights
class SimpleClassifier(nn.Module):
    """Binary bag classifier: a ResNet-50 applied to the bag-averaged image.

    ``forward`` averages the images of each bag element-wise over dim 1 and
    classifies that single averaged image. The output is a sigmoid probability
    per bag, so pair it with ``nn.BCELoss``, not ``BCEWithLogitsLoss``.

    Args:
        pretrained: when True (the default, matching the original behaviour) the
            backbone starts from torchvision's ImageNet-1k V1 weights. Pass False
            when a full checkpoint is loaded right after construction with
            ``load_state_dict``; this skips an unnecessary weight download. The
            parameter names and shapes, and therefore the state-dict keys, are
            identical either way.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights = ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.resnet = models.resnet50(weights=weights)
        # Swap the 1000-way ImageNet head for a 512-d projection.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(512, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return bag probabilities of shape ``(batch, 1)``.

        Assumes ``x`` is ``(batch, bag_size, 3, H, W)``, as produced by
        ``BagDataset`` + ``DataLoader`` (TODO confirm if other callers exist).
        """
        x = torch.mean(x, dim=1)  # pool the bag by averaging pixels across its images
        x = self.resnet(x)
        x = self.dropout(x)
        x = self.fc(x)
        x = self.sigmoid(x)
        return x

def load_data(train_path, test_path):
    """Load pickled bags from the train/test directory layout.

    Expected layout::

        train_path/class_0/*.pkl   -> label 0
        train_path/class_1/*.pkl   -> label 1
        test_path/*.pkl            -> unlabeled; id = file name without extension

    Files are read in sorted name order, so the returned lists (and therefore
    any submission built from them) are reproducible across machines.
    ``os.listdir`` order alone is arbitrary.

    Security: ``pickle.load`` can execute arbitrary code. Only use this on
    data from a trusted source.

    Returns:
        (train_bags, train_labels, test_bags, test_ids)
    """

    def _load_pickle(path):
        # Deserialize a single pickled bag file.
        with open(path, 'rb') as f:
            return pickle.load(f)

    train_bags = []
    train_labels = []
    for class_label in (0, 1):
        class_path = os.path.join(train_path, f'class_{class_label}')
        for file_name in sorted(os.listdir(class_path)):
            train_bags.append(_load_pickle(os.path.join(class_path, file_name)))
            train_labels.append(class_label)

    test_bags = []
    test_ids = []
    for file_name in sorted(os.listdir(test_path)):
        test_bags.append(_load_pickle(os.path.join(test_path, file_name)))
        # splitext strips only the final extension, so ids containing dots survive.
        test_ids.append(os.path.splitext(file_name)[0])

    return train_bags, train_labels, test_bags, test_ids

def predict(model, dataloader, threshold=0.5, device=None):
    """Run ``model`` over ``dataloader`` and return thresholded predictions.

    Args:
        model: module whose outputs are compared against ``threshold``. For
            ``SimpleClassifier`` these are sigmoid probabilities.
        dataloader: iterable of input batches. A batch that is a list/tuple
            (e.g. ``(bag, label)`` from a labelled ``BagDataset``) contributes
            only its first element.
        threshold: outputs strictly greater than this are predicted positive.
        device: where inputs are moved. Defaults to the device of the model's
            first parameter (CPU for a parameterless model), so the function
            does not depend on a global ``device`` variable.

    Returns:
        list of numpy bool arrays, one per sample (one row of the model output each).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            outputs = model(inputs.to(device))
            predictions.extend((outputs > threshold).cpu().numpy())
    return predictions

if __name__ == '__main__':
    data_dir = './data'
    train_path = os.path.join(data_dir, 'train')
    test_path = os.path.join(data_dir, 'test')
    model_path = '/content/drive/MyDrive/trained_classifier.pth'
    submission_path = '/content/drive/MyDrive/submission.csv'

    # NOTE: train_bags / train_labels are not used for inference; load_data
    # reads them anyway because it always loads both splits.
    train_bags, train_labels, test_bags, test_ids = load_data(train_path, test_path)

    # Inference must be deterministic. RandomResizedCrop (a training-time
    # augmentation) would feed a different random crop on every run and make
    # the submission non-reproducible. Resize + CenterCrop yields the same
    # 128x128 input size deterministically.
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = SimpleClassifier().to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    model.load_state_dict(torch.load(model_path, map_location=device))

    test_dataset = BagDataset(test_bags, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=2)

    predictions = predict(model, test_loader)
    predictions = [int(pred[0]) for pred in predictions]
    assert len(predictions) == len(test_ids), "one prediction per test bag expected"

    submission = pd.DataFrame({'image_id': test_ids, 'y_pred': predictions})
    submission.to_csv(submission_path, index=False)

    print(f"Submission file created: {submission_path}")


KeyboardInterrupt: 