In [1]:
import torch
import torchvision
import os
import random

import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pandas as pd
import numpy as np

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchensemble import VotingClassifier, FusionClassifier, BaggingClassifier

from transformers import BeitForImageClassification, BeitConfig, BeitFeatureExtractor

from PIL import Image
from AutoAugment.autoaugment import ImageNetPolicy

from tqdm import tqdm
from collections import defaultdict, OrderedDict

from torchensemble.utils import io

In [2]:
# Pick the compute device: the first CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

device

'cuda:0'

In [3]:
# Load the (image name, label) table.
# NOTE(review): with `names` supplied, header=1 makes pandas treat file line 1 as the
# header row, so lines 0 and 1 never reach the data. If the CSV has a single header
# line, header=0 is probably what was intended -- confirm against the file.
df = pd.read_csv("../data/train_labels.csv", names=['img_name', 'label'], header=1)

# Oversample label 20: add 15 extra copies of every row carrying that label.
# `DataFrame.append` was removed in pandas 2.0; `pd.concat` builds the same frame.
# NOTE(review): the copies are added before the shuffle below, and food_set later
# splits this frame by row position, so copies of one label-20 image can end up in
# both the training and the validation slice.
df = pd.concat([df] + [df[df["label"] == 20]] * 15, ignore_index=True)

# Shuffle all rows with a fixed seed so the positional split is reproducible.
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

df.shape

(31016, 2)

In [4]:
class food_set(Dataset):
    """Labelled image dataset backed by a positional slice of a label DataFrame.

    Args:
        df: DataFrame whose first column holds image file names and whose
            second column holds the labels (both are read by position).
        labels_file: Unused; kept so existing call sites keep working.
        img_dir: Directory that contains the image files named in ``df``.
        extractor: Callable invoked as ``extractor(images=image)``; its
            ``["pixel_values"][0]`` entry is returned as the sample features.
        transform: Optional callable applied to the PIL image before the
            extractor runs.
        settype: ``"train"`` selects ``df[:split_idx]``, ``"val"`` selects
            ``df[split_idx:]``.
        split_idx: Row position separating the training rows from the
            validation rows. Defaults to 30000, the previously hard-coded value.

    Raises:
        ValueError: If ``settype`` is neither ``"train"`` nor ``"val"``.
    """

    def __init__(self, df, labels_file, img_dir, extractor, transform = None, settype = "train", split_idx = 30000):
        self.df = df
        if settype == "train":
            self.img_labels = df[:split_idx]
        elif settype == "val":
            self.img_labels = df[split_idx:]
        else:
            # Fail at construction time; previously `img_labels` stayed unset and
            # the first len()/indexing call raised an AttributeError instead.
            raise ValueError(f"settype must be 'train' or 'val', got {settype!r}")
        self.img_dir = img_dir
        self.feature_extractor = extractor
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for the row at position ``idx`` of the split."""
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # Close the file handle deterministically and normalise to three channels;
        # convert() returns a fully loaded copy, so it stays valid after the close.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        features = self.feature_extractor(images=image)["pixel_values"][0]

        label = self.img_labels.iloc[idx, 1]

        return features, label

In [5]:
class food_test(Dataset):
    """Unlabelled image dataset over every entry of a directory.

    The directory listing is read once at construction and sorted, so the
    index-to-file mapping is stable and ``__len__``/``__getitem__`` no longer
    re-list the directory on every call.

    Args:
        img_dir: Directory whose entries are the images to serve.
        extractor: Callable invoked as ``extractor(images=image)``; its
            ``["pixel_values"][0]`` entry is returned as the sample features.
        transform: Optional callable applied to the PIL image before the
            extractor runs.
    """

    def __init__(self, img_dir, extractor, transform = None):
        self.img_dir = img_dir
        self.feature_extractor = extractor
        self.transform = transform
        # os.listdir returns entries in arbitrary order; snapshot and sort them once.
        self.file_names = sorted(os.listdir(img_dir))

    def __len__(self):
        """Return the number of directory entries captured at construction."""
        return len(self.file_names)

    def __getitem__(self, idx):
        """Return ``(file_name, features)`` for the ``idx``-th entry in sorted order."""
        file_name = self.file_names[idx]
        img_path = os.path.join(self.img_dir, file_name)
        # Close the file handle deterministically and normalise to three channels;
        # convert() returns a fully loaded copy, so it stays valid after the close.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        features = self.feature_extractor(images=image)["pixel_values"][0]

        return file_name, features

In [6]:
# Load the pretrained BEiT image classifier from the Hugging Face hub.
beit = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")

# Freeze every parameter that exists at this point. The replacement head below is
# created afterwards, so its parameters keep requires_grad=True.
for param in beit.parameters():
    param.requires_grad = False

# Replacement head: 768 inputs -> 512 hidden units -> 81 outputs, emitting
# log-probabilities (768 is presumably the BEiT-base hidden size and 81 the number
# of target classes -- confirm).
classifier = nn.Sequential(OrderedDict([
                          ('fc1', nn.Linear(768, 512)),
                          ('relu', nn.ReLU()),
                          ('fc2', nn.Linear(512, 81)),
                          ('output', nn.LogSoftmax(dim=1))
                          ]))

beit.classifier = classifier

# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only machine,
# matching the CPU fallback used when choosing `device`; the model is moved to the
# target device later.
beit.load_state_dict(torch.load("../../beit_2_fc_imgnetpol_29k_30_epochs.pth", map_location="cpu"))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


<All keys matched successfully>

In [7]:
class beit_2(nn.Module):
    """Thin adapter that exposes only the ``"logits"`` entry of the wrapped model's output."""

    def __init__(self, beit):
        super().__init__()
        # Registered as a submodule so its parameters are tracked by nn.Module.
        self.beit = beit

    def forward(self, img):
        """Run the wrapped model on ``img`` and return its ``"logits"`` entry."""
        outputs = self.beit(img)
        return outputs["logits"]

In [8]:
# Voting ensemble with ten base estimators derived from the wrapped BEiT model.
ensemble = VotingClassifier(
    estimator=beit_2(beit),
    n_estimators=10,
    # Follow the actual hardware instead of hard-coding True, so the CPU fallback
    # used when selecting `device` also works here.
    cuda=torch.cuda.is_available(),
)

ensemble.to(device);

In [9]:
# Loss function handed to the ensemble for training and evaluation.
criterion = nn.CrossEntropyLoss()
ensemble.set_criterion(criterion)

# Optimizer configuration for the ensemble: Adam with an explicit learning rate and betas.
ensemble.set_optimizer("Adam", lr=0.0002, betas=[0.9, 0.999])

In [13]:
# Locations of the training images and of the label CSV.
img_dir = "../data2/train_set_rmbg/train_set_rmbg"
labels = "../data/train_labels.csv"

# Augmentation pipeline for training images; it runs on the PIL image before the
# feature extractor is applied inside the dataset.
train_transforms = transforms.Compose(
    [
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        ImageNetPolicy(),
        transforms.RandomHorizontalFlip(),
    ]
)

# Deterministic pipeline for validation images: resize, then centre-crop to 224.
test_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
)

# Preprocessor matching the pretrained BEiT checkpoint loaded earlier.
feature_extractor = BeitFeatureExtractor.from_pretrained(
    "microsoft/beit-base-patch16-224-pt22k-ft22k"
)

In [14]:
# Both splits read from the same image directory and share the feature extractor;
# only the PIL-level transforms and the selected slice of `df` differ.
trainset = food_set(
    df, labels, img_dir, feature_extractor,
    transform=train_transforms,
    settype="train",
)
valset = food_set(
    df, labels, img_dir, feature_extractor,
    transform=test_transforms,
    settype="val",
)

# Wrap each split in a DataLoader (batches of 8, loaded in the main process).
# NOTE(review): the validation loader is shuffled as well; shuffling is not needed
# for evaluation.
trainloader = DataLoader(dataset=trainset, batch_size=8, shuffle=True, num_workers=0)
valloader = DataLoader(dataset=valset, batch_size=8, shuffle=True, num_workers=0)

len(trainset)

30000

In [None]:
# Train the ensemble for a single epoch on the training loader.
ensemble.fit(train_loader=trainloader, epochs=1)

In [14]:
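# Optional: restore a previously saved ensemble instead of running the fit cell above.
# NOTE(review): `ensemble_2` is not defined in the cells shown here; pass the ensemble
# object that should receive the saved weights.
# io.load(ensemble_2, "../../models/ensemble_10_beits_10_epochs/")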

In [16]:
# Evaluate the ensemble on the validation loader and keep the returned score.
accuracy = ensemble.evaluate(valloader)

In [None]:
# Display the score stored in `accuracy`.
accuracy

In [18]:
# Unlabelled test images; no extra transform is passed, so only the feature
# extractor preprocesses them.
test_set = food_test(
    img_dir="../data2/test_set_rmbg/test_set_rmbg",
    extractor=feature_extractor,
)
testloader = DataLoader(dataset=test_set, batch_size=16)

len(test_set)

7653

In [20]:
# Collect, per test image, its file name under "img" and its five highest-scoring
# class indices under the integer keys 1 (best) through 5.
test_results = defaultdict(list)

# Inference only: switch to eval mode and disable autograd so no graph is built
# for the trainable classifier heads.
ensemble.eval()
with torch.no_grad():
    for file_names, images in tqdm(testloader):
        scores = ensemble(images.to(device))

        # Compute the top-5 once per batch and move it to plain Python ints.
        top5_indices = scores.topk(5).indices.cpu().tolist()

        for name, ranked_labels in zip(file_names, top5_indices):
            test_results["img"].append(name)
            for rank, label_idx in enumerate(ranked_labels, start=1):
                test_results[rank].append(label_idx)

100%|████████████████████████████████████████| 479/479 [09:23<00:00,  1.18s/it]


In [21]:
# One row per test image: the file name plus its top-5 predicted class indices.
test_df = pd.DataFrame(data=test_results)

test_df.head()

Unnamed: 0,img,1,2,3,4,5
0,test_1.jpg,63,15,61,80,3
1,test_10.jpg,45,49,29,2,24
2,test_100.jpg,34,12,48,71,22
3,test_1000.jpg,15,29,10,37,23
4,test_1001.jpg,18,10,37,29,32


In [22]:
# test_df = test_df[["img", 1]].rename({"img": "img_name", 1 : "label"}, axis = 1)

# test_df.head()

In [23]:
# Write the prediction table to the submissions folder, omitting the DataFrame index.
test_df.to_csv("../../submissions/voting_ensemble.csv", index=False)