In [39]:
import torch
import numpy as np
import pickle
import matplotlib.pyplot as plt
from tqdm import tqdm
from resnet import resnet18
import time
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from utils import read_raw

In [40]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
save_best_path = './models/best_5x5_resnet.pth'

In [41]:
net = resnet18(num_classes=6).to(device)
net.load_state_dict(torch.load(save_best_path))
net = net.to(device)
net.eval()

ResNet(
  (conv1): Conv2d(7, 64, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReL

In [42]:
file_name = './data/newrawfile_ref.raw'
bands = [33, 34, 35, 36, 37, 38, 39]
raw = read_raw(file_name, shape=(750, 288, 384), setect_bands=bands, cut_shape=(750, 380))

In [43]:
data = []
block_size = 2
for i in range(0 + block_size, raw.shape[0] - block_size):
    for j in range(0 + block_size, raw.shape[1] - block_size):
        x = raw[i - block_size: i + block_size + 1, j - block_size: j + block_size + 1, :]
        x = x.transpose(2, 0, 1)
        data.append(x)
data = np.array(data)
data = np.asarray(data, dtype=np.float32)
data = torch.tensor(data)

In [44]:
# 将data放进网络中预测
pred = []
with torch.no_grad():
    for i in tqdm(range(data.shape[0])):
        x = data[i].unsqueeze(0).to(device)
        out = net(x)
        out = torch.argmax(out, dim=1)
        pred.append(out.item())

 10%|▉         | 27367/280496 [01:01<09:30, 443.47it/s]


KeyboardInterrupt: 

In [None]:
# 将预测结果转换为图像
pred = np.array(pred)
pred = pred.reshape((raw.shape[0], raw.shape[1]))

In [None]:
# 展示预测结果
plt.figure(figsize=(10, 10))
plt.imshow(pred)
plt.savefig('pred.png', dpi=300)