In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path
# Project root: the parent of the (symlink-resolved) current working directory.
BASE = Path.cwd().resolve().parent
print(f'BASE: {BASE}')

import torch
torch.manual_seed(0)  # fixed seed so torch-side randomness is repeatable across runs

# Select the compute device: prefer CUDA, then Apple MPS, then fall back to CPU.
# DEVICE and IS_CUDA_AVAILABLE are always defined, and the CUDA-only queries
# (current device / device name) run only when CUDA is usable, so this cell
# also works on MPS-only and CPU-only machines.
IS_MPS_AVAILABLE = torch.backends.mps.is_available() and torch.backends.mps.is_built()
print(f'Is mps avaliable? : {IS_MPS_AVAILABLE}')

IS_CUDA_AVAILABLE = torch.cuda.is_available()
print(f'Is cuda avaliable? : {IS_CUDA_AVAILABLE}')
print(f'cuda device count: {torch.cuda.device_count()}')
if IS_CUDA_AVAILABLE:
    # These two calls raise when no CUDA device/runtime is present.
    print(f'cuda current device: {torch.cuda.current_device()}')
    print(f'cuda device name: {torch.cuda.get_device_name()}')

if IS_CUDA_AVAILABLE:
    DEVICE = torch.device('cuda')
elif IS_MPS_AVAILABLE:
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')
print(f'device: {DEVICE}')

import h5py
from tqdm import tqdm

from sklearn.model_selection import train_test_split


BASE: C:\Users\zheng\Documents\GitHub\SDSSGalCat
Is mps avaliable? : False
Is cuda avaliable? : True
cuda device count: 1
cuda current device: 0
cuda device name: NVIDIA GeForce RTX 4060 Ti
device: cuda


In [2]:
# Load a pretrained ResNet-18 from torch hub (torchvision repo pinned to v0.10.0).
# ref: https://pytorch.org/hub/pytorch_vision_resnet/
model = torch.hub.load(
    repo_or_dir='pytorch/vision:v0.10.0',
    model='resnet18',
    pretrained=True,
)
# Switch to evaluation mode; eval() returns the module, so the cell displays its repr.
model.eval()

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to C:\Users\zheng/.cache\torch\hub\v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\zheng/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:04<00:00, 9.91MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [3]:
# Read the Galaxy10 images and their labels from the HDF5 file under BASE/data.
# data ref: https://astronn.readthedocs.io/en/stable/galaxy10sdss.html
with h5py.File(BASE / 'data/Galaxy10.h5', 'r') as h5_file:
    # `[:]` copies each dataset fully into memory, so both stay usable after the file closes
    images, labels = h5_file['images'][:], h5_file['ans'][:]

print(f'images.shape: {images.shape}, labels.shape: {labels.shape}')
# every image must have exactly one label
assert labels.shape[0] == images.shape[0]

images.shape: (21785, 69, 69, 3), labels.shape: (21785,)


In [7]:
from torchvision import transforms
from PIL import Image
# Preprocessing applied to each PIL image before it is fed to the network:
# resize to 224, convert to a tensor, then normalise each channel with the
# ImageNet mean/std used by the pretrained torchvision models.
preprocess_steps = [
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
preprocess = transforms.Compose(preprocess_steps)

In [15]:
# Turn the first galaxy image into a single-sample batch for the network.
img = Image.fromarray(images[0], mode='RGB')
input_tensor = preprocess(img)
print(input_tensor.shape)
# indexing with None prepends a batch axis of length 1
input_batch = input_tensor[None, ...]
print(input_batch.shape)

torch.Size([3, 224, 224])
torch.Size([1, 3, 224, 224])


In [18]:
# Forward the single-image batch through the model; no gradients are recorded.
with torch.no_grad():
    output = model(input_batch)

# Take the only row of the batch and turn its scores into probabilities.
first_sample_scores = output[0]
probabilities = torch.nn.functional.softmax(first_sample_scores, dim=0)
print(probabilities)

tensor([1.9724e-03, 1.6914e-03, 6.3476e-05, 8.6612e-05, 1.1132e-04, 4.7789e-03,
        1.4083e-05, 1.0921e-05, 5.7664e-05, 1.6972e-04, 2.1852e-04, 1.4466e-05,
        5.3572e-06, 7.1628e-06, 6.4568e-06, 1.3169e-05, 3.3717e-06, 3.1598e-06,
        1.5244e-05, 2.3823e-06, 4.6630e-05, 1.5864e-05, 1.5739e-04, 8.5525e-05,
        9.1811e-05, 3.5262e-05, 1.4006e-04, 5.1545e-05, 2.2844e-04, 2.5061e-03,
        2.1433e-05, 9.4596e-05, 1.8174e-04, 3.8943e-04, 8.6953e-04, 1.2722e-04,
        1.4212e-04, 1.3122e-05, 2.5656e-03, 1.7913e-04, 4.0059e-05, 1.4075e-04,
        8.6553e-05, 4.2352e-03, 1.4009e-04, 5.7675e-04, 2.0879e-05, 2.7877e-04,
        1.3059e-04, 7.4135e-05, 6.1240e-05, 1.1688e-03, 4.4015e-05, 2.8316e-05,
        2.5375e-05, 2.3510e-05, 2.6387e-05, 4.8829e-06, 1.3947e-05, 1.1758e-03,
        2.2857e-04, 1.8477e-05, 6.4050e-04, 9.2615e-04, 3.8121e-04, 1.2277e-03,
        4.5590e-04, 1.8334e-05, 5.3920e-04, 9.9821e-06, 1.2219e-05, 2.0904e-04,
        2.3430e-05, 5.4578e-03, 1.1523e-

In [23]:
# Run the model on a batch built from the first 10 images.
img_lst = [Image.fromarray(images[idx], mode='RGB') for idx in range(10)]
input_tensor_lst = list(map(preprocess, img_lst))
# stacking along a new leading axis is the same as concatenating unsqueezed tensors
input_batch = torch.stack(input_tensor_lst, dim=0)

with torch.no_grad():
    output = model(input_batch)

# softmax over dim=1 normalises each sample's scores independently
probabilities = torch.nn.functional.softmax(output, dim=1)
print(probabilities.shape, input_batch.shape)

torch.Size([10, 1000]) torch.Size([10, 3, 224, 224])
