In [2]:
import torch
import torch.nn as nn
from torchvision.transforms import transforms
from torch.utils.data import Dataset, DataLoader
from models.resnet_simclr import ResNetSimCLR
import sys
sys.path.append('../')
from pyn import Json
from PIL import Image
from tqdm import tqdm, trange
import numpy as np
import os
join = os.path.join

In [3]:
# Shared PIL-image -> tensor converter (torchvision ToTensor).
# The name keeps its original "Tesor" spelling because other cells may refer to it.
PIL2Tesor = transforms.ToTensor()


def img2tensor(img):
    """Load the image file at path ``img`` and convert it with ``PIL2Tesor``.

    Args:
        img: path of the image file to open.

    Returns:
        The tensor produced by ``PIL2Tesor`` for the opened image.

    ``Image.open`` is lazy and keeps the underlying file open; opening it in a
    context manager closes the handle as soon as the conversion (which reads
    the pixel data) has finished, instead of leaking one handle per sample
    until garbage collection.
    """
    with Image.open(img) as pil_img:
        return PIL2Tesor(pil_img)

class RockData(Dataset):
    """Map-style dataset of rock images addressed by relative path.

    Args:
        root: directory that the entries of ``image_paths`` are relative to.
        image_paths: sequence of image paths relative to ``root``; its order
            defines the dataset indices.
        labels: mapping from an image path (exactly as it appears in
            ``image_paths``) to its label.
        n_views: accepted for call-site compatibility; not used by this dataset.
        merge_label: if True, collapse fine labels into 3 coarse classes
            (for integer labels: 0-4 -> 0, 5-7 -> 1, 8-14 -> 2). Labels outside
            [0, 14] raise ``ValueError`` instead of leaking through unmerged.
        transform: callable applied to the joined image path; defaults to
            ``img2tensor``.

    Each item is ``(transform(root/path), label)`` where ``label`` is a
    ``torch.long`` scalar tensor.
    """

    def __init__(self, root, image_paths, labels, n_views, merge_label=True,
                 transform=None):
        self.root = root
        self.image_paths = image_paths
        self.labels = labels
        self.n_views = n_views  # stored only; nothing in this class reads it
        self.transform = img2tensor if transform is None else transform
        self.merge_label = merge_label

    @staticmethod
    def _merge_label(label):
        """Map a fine-grained label in [0, 14] to one of 3 coarse classes."""
        # The original note said "merge 14 labels into 3", but the ranges below
        # accept 0..14 inclusive (15 integer values).
        if 0 <= label <= 4:
            return 0
        if 4 < label <= 7:
            return 1
        if 7 < label <= 14:
            return 2
        # Previously such labels were returned unchanged, silently producing a
        # class index outside the 3-class range.
        raise ValueError(f'label {label!r} is outside the mergeable range [0, 14]')

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[image_path]
        if self.merge_label:
            label = self._merge_label(label)
        # No resizing or cropping happens here; the transform receives the full
        # file path. NOTE(review): an earlier comment claimed a 224 centre
        # resize; presumably the images are stored pre-sized — confirm.
        image_input = self.transform(os.path.join(self.root, image_path))
        return image_input, torch.tensor(label).long()

    def __len__(self):
        return len(self.image_paths)

In [4]:
# For each SimCLR checkpoint (one per training-subset size), extract backbone
# features for the first `ratio` training images and for the full test split,
# and save them as .npy files.
root = '../YJY_Rock'
FEATURE_DIR = '../YJY_Rock/Features_SimCLR/data'
RATIOS = [1000, 2000, 3000, 4000, 5000]


def _make_loader(dataset):
    """Deterministic loader: one sample per batch, no shuffling, nothing dropped."""
    # batch_size=1 is kept from the original run; image sizes are not checked
    # anywhere, so larger batches could fail if images differ in size.
    return DataLoader(dataset, batch_size=1, shuffle=False,
                      num_workers=0, pin_memory=False, drop_last=False)


def _extract_features(model, loader):
    """Run ``model`` on CUDA over every batch of ``loader``.

    Returns:
        (features, labels) as numpy arrays, each concatenated along axis 0 in
        loader order.
    """
    features, targets = [], []
    model.eval()
    with torch.no_grad():
        for X, y in tqdm(loader):
            features.append(model(X.cuda()).cpu())
            targets.append(y)
    return torch.cat(features).numpy(), torch.cat(targets).numpy()


# The split and label files do not depend on `ratio`, so load them once.
permute, labels = Json.load('./config/BoxImg_224_permute_0.json'),\
                    Json.load('./config/labels.json')
dev_set = permute['test_data']
os.makedirs(FEATURE_DIR, exist_ok=True)

for ratio in RATIOS:
    # First `ratio` training images; presumably the same subset size the
    # checkpoint directory Res18-256-{ratio} refers to — confirm.
    tr_set = permute['train_data'][:ratio]
    checkpoint = torch.load(f'./runs/Res18-256-{ratio}/checkpoint_0200.pth.tar', map_location="cuda:0")
    model = ResNetSimCLR(checkpoint['arch'], 128)
    model.load_state_dict(checkpoint['state_dict'])
    # Keep only the backbone and replace its `fc` layer with Identity so the
    # forward pass returns the input that `fc` would have received.
    model = model.backbone.cuda()
    model.fc = nn.Identity()

    # The 4th argument (n_views) is not used by RockData.
    train_dataset = RockData(root, tr_set, labels, 2)
    dev_dataset = RockData(root, dev_set, labels, 2)
    train_loader = _make_loader(train_dataset)
    dev_loader = _make_loader(dev_dataset)

    _tr_features, _tr_label = _extract_features(model, train_loader)
    np.save(join(FEATURE_DIR, f'tr_res18_256_{ratio}_X.npy'), _tr_features)
    np.save(join(FEATURE_DIR, f'tr_res18_256_{ratio}_y.npy'), _tr_label)

    _dev_features, _dev_label = _extract_features(model, dev_loader)
    np.save(join(FEATURE_DIR, f'tt_res18_256_{ratio}_X.npy'), _dev_features)
    np.save(join(FEATURE_DIR, f'tt_res18_256_{ratio}_y.npy'), _dev_label)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
100%|██████████| 1000/1000 [00:11<00:00, 84.30it/s]
100%|██████████| 12600/12600 [02:35<00:00, 81.03it/s] 
100%|██████████| 2000/2000 [00:22<00:00, 90.42it/s] 
100%|██████████| 12600/12600 [02:12<00:00, 94.98it/s] 
100%|██████████| 3000/3000 [00:30<00:00, 99.39it/s] 
100%|██████████| 12600/12600 [02:06<00:00, 99.81it/s] 
100%|██████████| 4000/4000 [00:40<00:00, 98.67it/s] 
100%|██████████| 12600/12600 [02:06<00:00, 99.41it/s] 
100%|██████████| 5000/5000 [00:52<00:00, 94.70it/s] 
100%|██████████| 12600/12600 [02:09<00:00, 97.35it/s] 


In [6]:
from collections import Counter

# Reload previously extracted train/test features and labels.
# NOTE(review): these file names carry no `_{ratio}` suffix, so they are not the
# files written by the per-ratio extraction loop; they must already exist from
# a separate run. Confirm which run produced them.
tr_X = np.load('../YJY_Rock/Features_SimCLR/data/tr_res18_256_X.npy')
tr_y = np.load('../YJY_Rock/Features_SimCLR/data/tr_res18_256_y.npy')
tt_X = np.load('../YJY_Rock/Features_SimCLR/data/tt_res18_256_X.npy')
tt_y = np.load('../YJY_Rock/Features_SimCLR/data/tt_res18_256_y.npy')

# Class distribution of the test labels.
print(Counter(tt_y))

Counter({2: 5188, 1: 4683, 0: 2729})
