In [1]:
import os
import numpy as np
from PIL import Image
import glob
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from tqdm import tqdm
from torchvision import models
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
from torchvision.datasets import ImageFolder

In [2]:
myseed = 666
# Trade cuDNN autotuning speed for run-to-run determinism.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(myseed)
# torch.manual_seed seeds the CPU generator and every CUDA device, so the
# separate torch.cuda.manual_seed / manual_seed_all calls are not needed.
torch.manual_seed(myseed)

In [3]:
# Deterministic evaluation preprocessing: resize every image to exactly
# 224x224 (aspect ratio is not preserved) and convert it to a float tensor.
# NOTE(review): there is no Normalize step — this must mirror the transform
# the checkpoint was trained with; confirm against the training code.
test_tfm = transforms.Compose(
    [transforms.Resize(size=(224, 224)), transforms.ToTensor()]
)

In [4]:
class TestingDataset(Dataset):
    """Unlabeled image dataset that yields one (optionally transformed) image per file.

    Expected layout::

        data
        ├── test
        |   ├── xxxxx.jpg
        |   ├── ...
        |   └── yyyyy.jpg

    Files are enumerated in sorted path order, so index ``i`` always refers to
    the same file, and ``names[i]`` is that file's basename without extension.
    """

    def __init__(self, img_dir, transform=None):
        """
        Args:
            img_dir: Directory whose regular files are the images to load.
            transform: Callable applied to each RGB ``PIL.Image``. If ``None``,
                ``__getitem__`` returns the RGB image itself.
        """
        self.img_dir = img_dir
        self.transform = transform
        # Sorted so dataset order (and DataLoader order with shuffle=False) is
        # deterministic; sub-directories are skipped.
        self.images = sorted(
            path
            for path in glob.glob(os.path.join(self.img_dir, "*"))
            if os.path.isfile(path)
        )
        # splitext handles extensions of any length (".jpg", ".jpeg", ".png").
        self.names = [
            os.path.splitext(os.path.basename(path))[0] for path in self.images
        ]

    def __len__(self):
        return len(self.images)

    def __getnames__(self):
        """Return the per-index file stems, aligned with ``__getitem__``."""
        return self.names

    def __getitem__(self, idx):
        # The context manager closes the underlying file handle. convert("RGB")
        # forces 3 channels so grayscale/RGBA/palette images batch correctly.
        with Image.open(self.images[idx]) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

In [5]:
# Unlabeled test images in sorted file order, preprocessed with test_tfm.
test_set = TestingDataset(img_dir="../data/test", transform=test_tfm)

In [6]:
class Resnet(nn.Module):
    """torchvision ResNet-50 with its final linear layer replaced for ``num_classes`` outputs."""

    def __init__(self, num_classes=200, pretrained=True):
        """
        Args:
            num_classes: Output size of the replacement classifier layer.
            pretrained: Forwarded to ``torchvision.models.resnet50``. When True,
                ImageNet weights are loaded (and downloaded if not cached) before
                the head is replaced. Pass False when a full checkpoint is loaded
                with ``load_state_dict`` right afterwards anyway.
        """
        super().__init__()
        self.resnet = models.resnet50(pretrained=pretrained)
        # Read the head's input width from the backbone instead of hard-coding 2048.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return the raw (un-normalized) class scores from the wrapped ResNet."""
        return self.resnet(x)

In [7]:
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 32

# shuffle=False keeps batches in the dataset's (sorted) file order, so outputs
# stay aligned with test_set.__getnames__().
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

model = Resnet().to(device)
# map_location lets a checkpoint saved from a GPU session load on the CPU
# fallback path. NOTE(review): torch.load unpickles the file — only load
# trusted checkpoints (consider weights_only=True on torch >= 1.13).
model.load_state_dict(torch.load("resnet50_.pth", map_location=device))



<All keys matched successfully>

In [8]:
threshold = 0.97  # minimum softmax confidence for a test image to be pseudo-labeled

softmax = nn.Softmax(dim=1)
model.eval()

# Collect per-batch results in lists and concatenate once at the end. This
# avoids re-copying a growing tensor every batch, and moving each batch to the
# CPU keeps accepted images from accumulating in GPU memory.
kept_data, kept_labels = [], []

with torch.no_grad():
    for data in tqdm(test_loader, desc="Pseudo Labeling"):
        data = data.to(device)  # honour the CPU fallback chosen for `device`
        probs = softmax(model(data))
        confidence, label = torch.max(probs, 1)
        mask = confidence > threshold
        kept_data.append(data[mask].cpu())
        kept_labels.append(label[mask].cpu())

# torch.cat rejects an empty list, so guard the empty-loader case.
# pseudo_data / pseudo_label end up on the CPU.
pseudo_data = torch.cat(kept_data, dim=0) if kept_data else torch.empty(0)
pseudo_label = (
    torch.cat(kept_labels, dim=0) if kept_labels else torch.empty(0, dtype=torch.long)
)

print("\nPseudo-labeling finished, %d samples generated." % len(pseudo_data))

# save 
np.save("pseudo_data.npy", pseudo_data.numpy())
np.save("pseudo_label.npy", pseudo_label.numpy())
print("Pseudo-label saved.")

Pseudo Labeling: 100%|██████████| 63/63 [00:24<00:00,  2.61it/s]



Pseudo-labeling finished, 859 samples generated.
Pseudo-label saved.
