In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
import numpy as np
from sklearn.metrics import accuracy_score
from torchvision import transforms
from astropy.io import fits 
from skimage.transform import resize
import time

import matplotlib
%matplotlib inline
#matplotlib.use('Agg')
from matplotlib import pyplot as plt

from weighting import weighting

# Enable the cuDNN autotuner. The flag lives under torch.backends.cudnn;
# assigning `torch.cuda.benchmark` only creates an attribute nothing reads.
torch.backends.cudnn.benchmark = True

# Root folder for the saved plots and model checkpoint used in this notebook.
IMG_PATH = "E:/Documents/Python_Scripts/CNN/TRAINING/"


In [2]:
class CategoricalNet(torch.nn.Module):
    """Three-convolution CNN with a two-layer fully connected head (10 outputs).

    The first linear layer expects 256 feature maps of 16x16, which is what the
    two 2x2 max-pools produce from a single-channel 64x64 input.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        conv_stack = [
            nn.Conv2d(1, 64, 5, padding=2),     # 1 input channel -> 64 maps, 5x5 kernel, 2-pixel padding
            nn.ReLU(),
            nn.Conv2d(64, 128, 5, padding=2),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 5, padding=2),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(2),
        ]
        self.feature_extractor = nn.Sequential(*conv_stack)

        head = [
            nn.Dropout(0.25),
            nn.Linear(256 * 16 * 16, 256),      # fully connected layer on the flattened feature maps
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, 10),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return log-softmax scores of shape (batch, 10) for the batch ``x``."""
        feature_maps = self.feature_extractor(x)
        flattened = feature_maps.view(x.size(0), -1)
        scores = self.classifier(flattened)
        return torch.nn.functional.log_softmax(scores, dim=1)

In [3]:
class RegressionNet(torch.nn.Module):
    """Deeper CNN: ten convolutions in five pooled stages, then a two-layer head.

    Despite the name, the forward pass ends in a log-softmax over 10 outputs.
    The first linear layer expects 128 feature maps of 2x2, which is what the
    five 2x2 max-pools produce from a single-channel 64x64 input.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        stages = [
            # stage 1: 5x5 convolutions, 1 -> 64 -> 128 channels
            nn.Conv2d(1, 64, 5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, 5, padding=2),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2),
            # stage 2: 128 -> 256 -> 512 channels
            nn.Conv2d(128, 256, 5, padding=2),
            nn.ReLU(),
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.MaxPool2d(2),
            # stage 3: 512 -> 512 -> 512 channels
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.MaxPool2d(2),
            # stage 4: 512 -> 256 -> 256 channels
            nn.Conv2d(512, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(2),
            # stage 5: 256 -> 128 -> 128 channels
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2),
        ]
        self.feature_extractor = nn.Sequential(*stages)

        head = [
            nn.Dropout(0.25),
            nn.Linear(128 * 2 * 2, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, 10),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return log-softmax scores of shape (batch, 10) for the batch ``x``."""
        feature_maps = self.feature_extractor(x)
        flattened = feature_maps.view(x.size(0), -1)
        scores = self.classifier(flattened)
        return torch.nn.functional.log_softmax(scores, dim=1)  # results reported via softmax

In [4]:
def weight_init(m):
    """Xavier-uniform initialise Conv2d/Linear weights and zero their biases.

    Meant to be passed to ``model.apply``. Modules of any other type are left
    untouched. Layers created with ``bias=False`` have ``m.bias is None``, so
    the bias reset is skipped for them instead of raising AttributeError.
    """
    if not isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        return
    torch.nn.init.xavier_uniform_(m.weight.data)
    if m.bias is not None:
        m.bias.data.zero_()

In [5]:
def PIL_transform(_img):
    """Rescale an image array to the full 0-255 range and return it as ``uint8``.

    NaN pixels are set to 0, the minimum is shifted to 0 and the maximum is
    scaled to 255; the float result is truncated by the ``uint8`` cast.

    The work is done on a copy, so the caller's array is not modified.
    Floating-point inputs keep their dtype for the arithmetic; other dtypes
    (e.g. integers) are promoted to float64 first, because the in-place
    scaling cannot be stored back into an integer array. A constant image has
    no range to stretch and comes back as all zeros rather than dividing by 0.
    """
    img = np.array(_img, copy=True)
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)
    if img.size == 0:
        return img.astype(np.uint8)

    img[img != img] = 0          # NaN is the only value that differs from itself
    img -= img.min()
    peak = img.max()
    if peak > 0:
        img *= 255. / peak
    return img.astype(np.uint8)

In [6]:
import torch.utils.data as data

# File-name suffixes that are treated as dataset images.
IMG_EXTENSIONS = [
    ".fits"
]


def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``."""
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    """Recursively collect the paths of all image files below ``dir``.

    Directories are visited in sorted order and the file names inside each
    directory are sorted as well, so the returned list (and therefore the
    index -> file mapping of any dataset built on it) does not depend on the
    order in which the filesystem happens to list entries.

    Raises:
        AssertionError: if ``dir`` is not an existing directory.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
    return images


def default_fits_loader(file_name: str, img_size: tuple, slice_index):
    """Load one slice of a FITS cube, resized to ``img_size``, with its label.

    The cube is read from HDU 1 and the label from the ``LABEL`` keyword of the
    primary header (HDU 0). The file is opened in a ``with`` block so the
    handle is released even if indexing, resizing or the header lookup raises.

    Returns:
        ``(_data, _label)`` where ``_data`` has a trailing channel axis added
        when the resized slice has fewer than three dimensions.
    """
    with fits.open(file_name) as file:
        # resize() returns a new array, so nothing refers to the file after the block.
        _data = resize(file[1].data[slice_index], img_size)
        _label = file[0].header['LABEL']

    if len(_data.shape) < 3:
        _data = _data.reshape((*_data.shape, 1))

    return _data, _label


class FITSCubeDataset(data.Dataset):
    """Dataset that exposes every slice of every FITS cube under ``data_path``.

    Each cube file contributes ``cube_length`` consecutive indices; index ``i``
    maps to slice ``i % cube_length`` of file ``i // cube_length``.
    """

    def __init__(self, data_path, cube_length, transforms, img_size):
        self.data_path = data_path
        self.transforms = transforms
        self.img_size = img_size
        self.cube_length = cube_length
        self.img_files = make_dataset(data_path)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the flat slice index ``index``.

        The slice is rescaled to 8-bit by ``PIL_transform`` (which also zeroes
        NaNs) and then passed through ``self.transforms`` when one is set.
        With ``transforms=None`` the 8-bit array is returned as-is; previously
        that case raised UnboundLocalError because nothing was assigned.
        """
        cube_index, slice_index = divmod(index, self.cube_length)
        _img, _label = default_fits_loader(self.img_files[cube_index], self.img_size, slice_index)
        _img = PIL_transform(_img)
        if self.transforms is not None:
            _img = self.transforms(_img)
        return _img, _label

    def __len__(self):
        """Total number of slices: number of cube files times ``cube_length``."""
        return len(self.img_files)*self.cube_length

In [7]:
IMG_PATH = "E:/Documents/Python_Scripts/CNN/TRAINING/"

def plot_accuracy(accuracies, val_acc, epochs, filename):
    """Save a training-vs-validation accuracy curve to ``IMG_PATH + filename``.

    The x axis spans 0..max(epochs) and the y axis 0..100. The figure is
    closed after saving so repeated calls do not accumulate open figures.
    """
    fig = plt.figure()
    axes = fig.gca()

    axes.set_xlim(0, max(epochs))
    axes.set_ylim(0, 100)

    # zorder draws the training curve on top of the validation curve.
    axes.plot(epochs, accuracies, 'b', label='Training Accuracy', zorder=1)
    axes.plot(epochs, val_acc, 'purple', label='Validation Accuracy', zorder=0)

    axes.set_ylabel('Accuracy')
    axes.set_xlabel('Training Epoch')
    axes.legend(loc='best', fontsize='small')

    fig.savefig(IMG_PATH+filename, bbox_inches='tight')
    plt.close(fig)

In [8]:
def adjust_learning_rate(optimizer, epoch, initial_lr, num_epochs):
    """Set every parameter group's learning rate on a linear decay schedule.

    The rate starts at ``initial_lr`` for ``epoch == 0`` and drops by
    ``initial_lr / num_epochs`` per epoch. The new value is printed and written
    to each entry of ``optimizer.param_groups``.
    """
    per_epoch_drop = initial_lr / num_epochs
    new_lr = initial_lr - per_epoch_drop*epoch
    print("Set LR to %f" % new_lr)
    for group in optimizer.param_groups:
        group.update(lr=new_lr)

In [9]:
def _build_weighted_sampler(dataset, scan_batch_size=6399):
    """Read every label of ``dataset`` and build a ``WeightedRandomSampler`` from them.

    The dataset is iterated once in chunks of ``scan_batch_size`` purely to
    collect labels; the labels are handed to ``weighting`` (project helper,
    presumably returning one sampling weight per sample -- verify there) and
    the sampler draws ``len(weights)`` indices per pass.
    """
    label_chunks = []
    for _, target in tqdm(DataLoader(dataset, batch_size=scan_batch_size, shuffle=False)):
        label_chunks.append(np.array(target.numpy()))
    if not label_chunks:
        raise ValueError("dataset is empty: no labels to build sampling weights from")
    labels = np.hstack(label_chunks)
    print(len(labels))
    print('Number of labels=',len(set(labels)))
    weights = weighting(labels)
    return WeightedRandomSampler(weights, len(weights))


def _batch_accuracy(pred, target, categorical):
    """Percentage of samples in a batch whose rounded prediction equals the target."""
    pred_npy = pred.detach().cpu().numpy()
    target_npy = target.detach().cpu().numpy()

    if categorical:
        pred_npy = np.argmax(pred_npy, axis=1)

    ###Change the error metric here###

    pred_int = np.round(pred_npy).astype(np.uint8).reshape(-1)
    target_npy = target_npy.astype(np.uint8).reshape(-1)
    return accuracy_score(target_npy, pred_int)*100


def train(model: torch.nn.Module, 
          transforms, 
          data_path="E:/Documents/Python_Scripts/CNN/TRAINING/EXAMPLES/", 
          val_path="E:/Documents/Python_Scripts/CNN/TRAINING/EXAMPLES/",  
          num_epochs=50, 
          batch_size=32, 
          verbose=True,
          cube_length=640, img_size=(64, 64), 
          loss=torch.nn.CrossEntropyLoss(), 
          lr_schedule=True, initial_lr=1e-3, suffix=""):
    """Train ``model`` on FITS cubes and plot train/validation accuracy per epoch.

    Args:
        model: network to optimise; it is moved to the selected device in place.
        transforms: callable applied to each 8-bit image by ``FITSCubeDataset``.
        data_path: directory holding the training cubes.
        val_path: directory holding the validation cubes.
        num_epochs: number of passes over the weighted training sampler.
        batch_size: mini-batch size for both loaders.
        verbose: accepted for compatibility; currently unused.
        cube_length: number of slices taken from every cube file.
        img_size: size each slice is resized to.
        loss: criterion; with ``CrossEntropyLoss`` targets are cast to long and
            predictions are arg-maxed for the accuracy, otherwise targets are
            cast to float and predictions are rounded.
        lr_schedule: when True the learning rate is linearly decayed each epoch
            via ``adjust_learning_rate``; it also selects the plot file name.
        initial_lr: Adam learning rate (starting value when scheduling).
        suffix: appended to the accuracy-plot file name.

    Returns:
        None. Progress is printed and the accuracy plot is rewritten after
        every epoch; the model is left in eval mode.
    """
    data_path = os.path.abspath(data_path)
    val_path = os.path.abspath(val_path)

    # Fall back to the CPU instead of failing outright when no GPU is visible.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.train().to(device).to(torch.float)

    categorical = isinstance(loss, torch.nn.CrossEntropyLoss)
    target_dtype = torch.long if categorical else torch.float

    train_set = FITSCubeDataset(data_path, cube_length, transforms, img_size)
    val_set = FITSCubeDataset(val_path, cube_length, transforms, img_size)

    start = time.time()
    print('Creating sampling weight array')
    sampler = _build_weighted_sampler(train_set)
    end = time.time()
    print('Weights Created in %.2gs'%(end-start))

    start = time.time()
    val_sampler = _build_weighted_sampler(val_set)
    end = time.time()
    print('Validation weights Created in %.2gs'%(end-start))

    loader = DataLoader(train_set, batch_size, shuffle=False, sampler=sampler)
    # The validation loader must read the same directory its sampler was sized
    # from (val_path); it was previously built on data_path.
    validation_loader = DataLoader(val_set, batch_size, shuffle=False, sampler=val_sampler)

    optim = torch.optim.Adam(model.parameters(), initial_lr)

    accuracies, val_accuracies, epochs = [0], [0], [0]

    if lr_schedule:
        plot_name = "Validation_accuracy_scheduler2%s.png" % suffix
    else:
        plot_name = "Validation_accuracy_no_scheduler2%s.png" % suffix

    for i in range(num_epochs):
        print("Epoch %d of %d" % (i+1, num_epochs))
        if lr_schedule:
            # Previously the flag only changed the plot name and the rate never moved.
            adjust_learning_rate(optim, i, initial_lr, num_epochs)

        _accuracies = []
        _val_accuracies = []

        model.train(True)
        for batch, target in tqdm(loader):
            batch = batch.to(device).to(torch.float)
            target = target.to(device).to(target_dtype)
            pred = model(batch)

            loss_value = loss(pred, target)

            optim.zero_grad()
            loss_value.backward()
            optim.step()

            _accuracies.append(_batch_accuracy(pred, target, categorical))

        epochs.append(i+1)

        mean_accuracy = sum(_accuracies)/len(_accuracies)
        accuracies.append(mean_accuracy)

        print("Mean accuracy: %f" % mean_accuracy)

        model.train(False)
        # No gradients are needed for evaluation; skipping the autograd graph
        # keeps activation memory from piling up on the device.
        with torch.no_grad():
            for batch, target in tqdm(validation_loader):
                batch = batch.to(device).to(torch.float)
                target = target.to(device).to(target_dtype)
                pred = model(batch)
                _val_accuracies.append(_batch_accuracy(pred, target, categorical))

        mean_accuracy = sum(_val_accuracies)/len(_val_accuracies)
        val_accuracies.append(mean_accuracy)
        plot_accuracy(accuracies, val_accuracies, epochs, plot_name)
        print("Mean Validation accuracy: %f" % mean_accuracy)

    model.eval()


In [10]:
if __name__ == '__main__':
    print("Creating Model and Initializing weights")

#    for model_class, loss_fn, suffix in zip([CategoricalNet, RegressionNet], [torch.nn.CrossEntropyLoss(), torch.nn.MSELoss()], ["_categorical", "_regression"]):
#        for schedule in [True, False]:

    # NOTE(review): RegressionNet is trained with CrossEntropyLoss while the plot
    # suffix says "_categorical" -- confirm this pairing is intentional.
    model_class, loss_fn, suffix = RegressionNet, torch.nn.CrossEntropyLoss(), "_categorical"
    schedule = True

    model = model_class()
    model.apply(weight_init)
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

    # The run itself stays under the guard: it uses names that are only defined
    # above, so leaving it at top level would fail whenever the guard is false.
    start = time.time()
    train(model, transform, num_epochs=100, batch_size=32, lr_schedule=schedule, loss=loss_fn, suffix=suffix)
    end = time.time()
    print('TRAIN TIME:')
    print('%.2gs'%(end-start))

Creating Model and Initializing weights
Creating sampling weight array


  warn("The default mode, 'constant', will be changed to 'reflect' in "
  return umr_minimum(a, axis, None, out, keepdims)
  return umr_maximum(a, axis, None, out, keepdims)
100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:21<00:00, 10.69s/it]


6400
Number of labels= 10
Weights Created in 21s


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:20<00:00, 10.11s/it]


6400
Number of labels= 10
Validation weights Created in 20s
Epoch 1 of 100


  0%|▍                                                                                 | 1/200 [00:01<06:00,  1.81s/it]


RuntimeError: CUDA error: out of memory

In [None]:
# Persist the trained network as a whole pickled module object (not just a
# state_dict) under IMG_PATH; the load cell that follows reads the same file.
torch.save(model, IMG_PATH+'RegressionNet.pt')

In [None]:
IMG_PATH = "E:/Documents/Python_Scripts/CNN/TRAINING/"
# map_location='cpu' lets the checkpoint load on a machine without the GPU it was
# saved from; eval() puts dropout and batch-norm into inference mode for the
# prediction cells that follow, whatever mode the module was pickled in.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects full-module pickles -- pass weights_only=False there if needed.
tester = torch.load(IMG_PATH+'RegressionNet.pt', map_location='cpu').eval()

In [None]:
# Display the loaded module's repr (its layer listing) as the cell output.
tester

In [None]:
# Print the kernel tensor of every convolutional layer in the loaded network.
for m in filter(lambda layer: isinstance(layer, torch.nn.Conv2d), tester.modules()):
    print(m.weight.data)

In [None]:
# Open one example cube for manual inspection; the handle stays open until the
# later cell that calls test.close().
test = fits.open("E:/Documents/Python_Scripts/CNN/TRAINING/EXAMPLES/RefL0025N0376,28,12,0,296111.fits")

In [None]:
# Inspect the opened cube: list its HDUs, print the label stored in the primary
# header, then rescale slice 200 to 8-bit and save a colour-mapped view of it.
test.info()
print('TRUE LABEL=',test[0].header['LABEL'])
d = test[1].data[200]
dat = PIL_transform(d)
# The pyplot calls below all act on the figure created here (implicit current figure).
plt.figure()
plt.imshow(dat,cmap='jet')
plt.colorbar()
plt.savefig(IMG_PATH+'before')

In [None]:
# Same preprocessing as the training cell: convert to a tensor, then normalise
# with mean 0.5 and std 0.5.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])

In [None]:
# Classify the inspected slice. The tensor is called `sample` rather than `data`
# so the module alias `data` (torch.utils.data, used as FITSCubeDataset's base
# class) is not overwritten in the notebook namespace.
sample = transform(dat.reshape(*dat.shape,1)).unsqueeze(0).float()
plt.figure()
plt.imshow(sample[0,0,:,:],cmap='jet')
plt.colorbar()
plt.savefig(IMG_PATH+'after')
print(sample.shape)

# The network already returns log-softmax scores; softmax turns them back into probabilities.
output = torch.nn.functional.softmax(tester(sample),dim=1).detach().numpy()
print(output)
print('TEST LABEL= ',np.argmax(output))
test.close()

In [None]:
# Classify slice 30 of a second example cube. The file is opened in a `with`
# block so the handle is always released; PIL_transform returns a new uint8
# array, so `d` remains usable after the file is closed.
with fits.open("E:/Documents/Python_Scripts/CNN/TRAINING/EXAMPLES/RefL0100N1504,28,1222,0,9017403.fits") as test2:
    print('TRUE LABEL=',test2[0].header['LABEL'])
    d = PIL_transform(test2[1].data[30])

# Named `sample` rather than `data` so the `torch.utils.data` alias imported as
# `data` is not overwritten in the notebook namespace.
sample = transform(d.reshape(*d.shape,1)).unsqueeze(0).float()
plt.figure()
plt.imshow(sample[0,0,:,:],cmap='jet')
plt.colorbar()
plt.show()
output = torch.nn.functional.softmax(tester(sample),dim=1).detach().numpy()
print(output)
print(np.argmax(output))

In [None]:
# Sanity check: run the classifier on a 64x64 image of uniform random noise
# (values drawn from [0, 10000)) and show the normalised input.
data1 = np.random.uniform(low=0, high=10000, size=[64, 64])
d1 = transform(PIL_transform(data1)[..., np.newaxis]).unsqueeze(0).float()

probabilities = torch.nn.functional.softmax(tester(d1), dim=1)
output = probabilities.detach().numpy()
print(output)
print(np.argmax(output))

plt.figure()
plt.imshow(d1[0, 0, :, :], cmap='jet')
plt.colorbar()
plt.show()

In [None]:
# Sanity check on low-amplitude uniform noise (values drawn from [0, 10)).
# The raw array is called `noise` rather than `data` so the `torch.utils.data`
# alias imported as `data` is not overwritten in the notebook namespace.
noise = np.random.uniform(0,10,[64,64])
d = PIL_transform(noise)
# Feed the 8-bit image `d` (not the raw float array) to the transform so ToTensor
# applies the same 0-1 scaling the training images received.
d = transform(d.reshape(*d.shape,1)).unsqueeze(0).float()
plt.figure()
plt.imshow(d[0,0,:,:],cmap='jet')
plt.colorbar()
plt.savefig(IMG_PATH+'test_image',bbox_inches='tight')
output = torch.nn.functional.softmax(tester(d),dim=1).detach().numpy()
print(np.argmax(output))
print(output)