In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
import numpy as np
from sklearn.metrics import accuracy_score
from torchvision import transforms
from astropy.io import fits 
from skimage.transform import resize
import time

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt

from weighting import weighting

In [None]:
class CategoricalNet(torch.nn.Module):
    """Small CNN classifier that returns one raw score (logit) per class.

    The feature extractor stacks three 5x5 convolutions (padding 2, so the
    spatial size is unchanged by the convolutions) and two 2x2 max-poolings.
    The first fully connected layer expects ``256 * 16 * 16`` inputs, i.e. a
    16x16 feature map after the two poolings, which corresponds to a 64x64
    input image.

    Args:
        num_classes: Number of output scores. Defaults to 9.
        in_channels: Number of channels of the input image. Defaults to 1.
    """

    def __init__(self, num_classes=9, in_channels=1):
        super().__init__()
        self.feature_extractor = torch.nn.Sequential(
            # in_channels -> 64 channels, 5x5 kernel, 2-pixel zero padding
            torch.nn.Conv2d(in_channels, 64, 5, padding=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 128, 5, padding=2),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(128),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(128, 256, 5, padding=2),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(256),
            torch.nn.MaxPool2d(2))

        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(0.25),
            torch.nn.Linear(256 * 16 * 16, 256),  # fully connected layer on the flattened 16x16 map
            torch.nn.ReLU(),
            torch.nn.Dropout(0.25),
            torch.nn.Linear(256, num_classes))

    def forward(self, x):
        """Return the unnormalised class scores for a batch ``x``.

        No softmax is applied here; the output of the last linear layer is
        returned as-is.
        """
        features = self.feature_extractor(x)
        # Flatten everything except the batch dimension before the dense layers.
        return self.classifier(features.view(x.size(0), -1))

In [None]:
class RegressionNet(torch.nn.Module):
    """Deeper CNN that maps an image to a single scalar output.

    The feature extractor consists of five stages. Each stage is two
    convolution + ReLU pairs followed by batch normalisation and a 2x2
    max-pooling. The dense head expects ``128 * 2 * 2`` inputs, i.e. a 2x2
    feature map after the five poolings (a 64x64 input image).
    """

    # One entry per stage; each stage lists its two convolutions as
    # (in_channels, out_channels, kernel_size). Padding is kernel_size // 2,
    # so the convolutions keep the spatial size.
    _CONV_STAGES = (
        ((1, 64, 5), (64, 128, 5)),
        ((128, 256, 5), (256, 512, 3)),
        ((512, 512, 3), (512, 512, 3)),
        ((512, 256, 3), (256, 256, 3)),
        ((256, 128, 3), (128, 128, 3)),
    )

    def __init__(self):
        super().__init__()

        layers = []
        for stage in self._CONV_STAGES:
            for in_ch, out_ch, kernel in stage:
                layers.append(torch.nn.Conv2d(in_ch, out_ch, kernel, padding=kernel // 2))
                layers.append(torch.nn.ReLU())
            # Normalise over the channel count produced by the stage's last convolution.
            layers.append(torch.nn.BatchNorm2d(stage[-1][1]))
            layers.append(torch.nn.MaxPool2d(2))
        self.feature_extractor = torch.nn.Sequential(*layers)

        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(0.25),
            torch.nn.Linear(128 * 2 * 2, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.25),
            torch.nn.Linear(256, 1))

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor with one value per input image."""
        batch_size = int(x.size()[0])
        flat_features = self.feature_extractor(x).view(batch_size, -1)
        return self.classifier(flat_features)

In [None]:
def weight_init(m):
    """Initialise one module in place; intended for ``model.apply(weight_init)``.

    ``Conv2d`` and ``Linear`` modules get Xavier-uniform weights and a zeroed
    bias. Every other module type is left untouched.

    Args:
        m: The module visited by ``Module.apply``.
    """
    if not isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        return
    torch.nn.init.xavier_uniform_(m.weight.data)
    # Layers built with bias=False have ``bias is None``; there is nothing to zero.
    if m.bias is not None:
        m.bias.data.zero_()

In [None]:
import torch.utils.data as data

# File-name suffixes that identify an image file (matched case-sensitively).
IMG_EXTENSIONS = [".fits"]


def is_image_file(filename):
    """Return True if ``filename`` ends with one of the ``IMG_EXTENSIONS`` suffixes."""
    # str.endswith accepts a tuple of candidate suffixes.
    return filename.endswith(tuple(IMG_EXTENSIONS))


def make_dataset(dir):
    """Recursively collect the paths of all image files below ``dir``.

    Directories are visited in sorted order and the file names inside each
    directory are sorted as well, so the returned list (and therefore any
    index-to-file mapping built on it) does not depend on the order in which
    the filesystem happens to list entries.

    Args:
        dir: Root directory to search.

    Returns:
        List of paths for which ``is_image_file`` is true.

    Raises:
        AssertionError: If ``dir`` is not an existing directory.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        # os.walk yields file names in arbitrary order; sort them for reproducibility.
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
    return images


def default_fits_loader(file_name: str, img_size: tuple, slice_index):
    """Load one slice of a FITS cube together with the cube's label.

    The data are read from HDU 1, indexed with ``slice_index`` and resized to
    ``img_size``. The label is the ``LABEL`` card of the HDU 0 header. The file
    is closed before returning, so repeated calls do not accumulate open file
    handles.

    Args:
        file_name: Path of the FITS file.
        img_size: Target shape passed to ``resize``.
        slice_index: Index of the slice to take from the HDU 1 data.

    Returns:
        ``(data, label)``. ``data`` gets a trailing axis of length 1 appended
        when the resized array has fewer than three dimensions.
    """
    # The context manager closes the file even if reading or resizing fails.
    with fits.open(file_name) as hdul:
        _data = resize(hdul[1].data[slice_index], img_size)
        _label = hdul[0].header['LABEL']

    if _data.ndim < 3:
        # Append a trailing axis of length 1 so the array is at least 3-D.
        _data = _data.reshape((*_data.shape, 1))

    return _data, _label


class FITSCubeDataset(data.Dataset):
    """Dataset in which every slice of every cube file is one sample.

    Sample ``index`` maps to file ``index // cube_length`` (in the order
    returned by ``make_dataset``) and to slice ``index % cube_length`` inside
    that file.

    Args:
        data_path: Directory searched by ``make_dataset``.
        cube_length: Number of slices counted per file.
        transforms: Optional callable applied to each loaded image; ``None``
            returns the image unchanged.
        img_size: Target size forwarded to ``default_fits_loader``.
    """

    def __init__(self, data_path, cube_length, transforms, img_size):
        self.data_path = data_path
        self.cube_length = cube_length
        self.img_size = img_size
        self.transforms = transforms
        self.img_files = make_dataset(data_path)

    def __len__(self):
        """Total number of slices: files times slices per file."""
        return self.cube_length * len(self.img_files)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the flat sample ``index``."""
        slice_index = index % self.cube_length
        cube_path = self.img_files[index // self.cube_length]
        image, label = default_fits_loader(cube_path, self.img_size, slice_index)

        # NaN is the only value that is unequal to itself; replace NaNs with 0.
        nan_mask = image != image
        image[nan_mask] = 0

        if self.transforms is None:
            return (image, label)
        return (self.transforms(image), label)

In [None]:
# Directory prefix (with trailing slash) that plot file names are appended to.
IMG_PATH = "E:/Documents/Python_Scripts/CNN/TRAINING/"


def plot_accuracy(accuracies, val_acc, epochs, filename):
    """Save a plot of training and validation accuracy against epoch.

    Args:
        accuracies: Training accuracy per entry of ``epochs``.
        val_acc: Validation accuracy per entry of ``epochs``.
        epochs: Epoch numbers used as x values; the x axis spans 0..max(epochs).
        filename: File name appended to ``IMG_PATH`` to form the output path.
    """
    fig = plt.figure()
    axes = fig.gca()

    # Fixed axis ranges: epochs on x, a 0-100 scale on y.
    axes.set_xlim(0, max(epochs))
    axes.set_ylim(0, 100)

    axes.plot(epochs, accuracies, 'g', label='Training Accuracy')
    axes.plot(epochs, val_acc, 'purple', label='Validation Accuracy')
    axes.set_ylabel('Accuracy')
    axes.set_xlabel('Training Epoch')
    axes.legend(loc='best', fontsize='small')

    fig.savefig(IMG_PATH + filename, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
def adjust_learning_rate(optimizer, epoch, initial_lr, num_epochs):
    """Set a linearly decayed learning rate on every parameter group.

    The rate is ``initial_lr`` at epoch 0 and drops by ``initial_lr / num_epochs``
    for each further epoch. The new value is printed and written to the ``'lr'``
    entry of each group in ``optimizer.param_groups``.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dict-like groups).
        epoch: Zero-based epoch number.
        initial_lr: Learning rate at epoch 0.
        num_epochs: Total number of epochs the decay is spread over.
    """
    per_epoch_drop = initial_lr / num_epochs
    new_lr = initial_lr - per_epoch_drop * epoch
    print("Set LR to %f" % new_lr)
    for group in optimizer.param_groups:
        group['lr'] = new_lr

In [None]:
def _batch_accuracy(pred, target, is_classification):
    """Return the percentage of samples in one batch whose prediction matches the target."""
    pred_npy = pred.detach().cpu().numpy()
    target_npy = target.detach().cpu().numpy()

    if is_classification:
        # Class scores -> index of the highest score.
        pred_npy = np.argmax(pred_npy, axis=1)

    # int64 rather than uint8: negative or >255 values would wrap around in
    # uint8 and could be counted as matches by accident.
    pred_int = np.round(pred_npy).astype(np.int64).reshape(-1)
    target_int = target_npy.astype(np.int64).reshape(-1)

    return accuracy_score(target_int, pred_int) * 100


def _run_epoch(model, loader, loss, device, optim=None):
    """Run one pass over ``loader`` and return the mean per-batch accuracy in percent.

    With an optimizer the model is put in training mode and updated after every
    batch. Without one the model is put in eval mode and gradients are disabled.
    """
    is_classification = isinstance(loss, torch.nn.CrossEntropyLoss)
    training = optim is not None
    model.train(training)

    batch_accuracies = []
    # No autograd graph is needed when only evaluating.
    with torch.set_grad_enabled(training):
        for batch, target in tqdm(loader):
            batch = batch.to(device).to(torch.float)
            target = target.to(device).to(torch.long if is_classification else torch.float)
            pred = model(batch)

            # A (B, 1) prediction against a (B,) target would broadcast to
            # (B, B) inside element-wise losses; align the shapes first.
            if (not is_classification and target.shape != pred.shape
                    and target.numel() == pred.numel()):
                target = target.reshape(pred.shape)

            if training:
                loss_value = loss(pred, target)
                optim.zero_grad()
                loss_value.backward()
                optim.step()

            batch_accuracies.append(_batch_accuracy(pred, target, is_classification))

    if not batch_accuracies:
        return 0.0
    return sum(batch_accuracies) / len(batch_accuracies)


def train(model: torch.nn.Module, transforms, data_path="E:/Documents/Python_Scripts/CNN/TRAINING/EXAMPLES/", num_epochs=50, batch_size=32, verbose=True,
          cube_length=640, img_size=(64, 64), loss=torch.nn.MSELoss(), lr_schedule=True, initial_lr=1e-3, suffix=""):
    """Train ``model`` on the FITS cubes under ``data_path`` and plot accuracy per epoch.

    Samples are drawn with a ``WeightedRandomSampler`` whose weights come from
    the project helper ``weighting`` applied to the labels. After every epoch
    an accuracy plot is written via ``plot_accuracy``. The model is left in
    eval mode on return.

    Args:
        model: Network to train; moved to the selected device as float32.
        transforms: Callable applied to each image by ``FITSCubeDataset``.
        data_path: Directory containing the FITS cubes.
        num_epochs: Number of training epochs.
        batch_size: Mini-batch size of the training loader.
        verbose: Currently unused.
        cube_length: Number of slices per cube file.
        img_size: Size every slice is resized to.
        loss: Loss module. ``CrossEntropyLoss`` selects classification handling
            (long targets, argmax predictions); anything else is treated as
            regression with float targets.
        lr_schedule: When true, the learning rate is decayed linearly each
            epoch with ``adjust_learning_rate``; also selects the plot file name.
        initial_lr: Learning rate of the Adam optimizer at epoch 0.
        suffix: String inserted into the plot file name.

    Returns:
        None.
    """
    data_path = os.path.abspath(data_path)

    # Fall back to the CPU when no CUDA device is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.train().to(device).to(torch.float)

    # One dataset instance serves both loaders (the directory is scanned once).
    dataset = FITSCubeDataset(data_path, cube_length, transforms, img_size)

    start = time.time()
    print('Creating sampling weight array')
    label_loader = DataLoader(dataset, batch_size=640*10, shuffle=False)
    # NOTE(review): only the first batch (at most 6400 samples) is used, so the
    # sampler can never draw an index beyond that range. Confirm this matches
    # the dataset size, or collect the labels of every batch instead.
    _, labels = next(iter(label_loader))
    weights = weighting(labels)
    sampler = WeightedRandomSampler(weights, len(weights))
    end = time.time()
    print('Weights Created in %.2gs'%(end-start))

    loader = DataLoader(dataset, batch_size, shuffle=False, sampler=sampler)
    optim = torch.optim.Adam(model.parameters(), initial_lr)

    # Index 0 is a placeholder so the curves start at (epoch 0, 0 %).
    accuracies, val_accuracies, epochs = [0], [0], [0]

    if lr_schedule:
        plot_name = "Validation_accuracy_scheduler4%s.png" % suffix
    else:
        plot_name = "Validation_accuracy_no_scheduler%s.png" % suffix

    for i in range(num_epochs):
        print("Epoch %d of %d" % (i+1, num_epochs))

        if lr_schedule:
            # Linear decay from initial_lr; without this call the flag would
            # only change the plot file name.
            adjust_learning_rate(optim, i, initial_lr, num_epochs)

        mean_accuracy = _run_epoch(model, loader, loss, device, optim)
        epochs.append(i+1)
        accuracies.append(mean_accuracy)
        print("Mean accuracy: %f" % mean_accuracy)

        # NOTE(review): this pass re-samples the training loader; there is no
        # held-out validation set, so the number is not a true validation score.
        mean_accuracy = _run_epoch(model, loader, loss, device)
        val_accuracies.append(mean_accuracy)
        plot_accuracy(accuracies, val_accuracies, epochs, plot_name)
        print("Mean Validation accuracy: %f" % mean_accuracy)

    model.eval()


In [None]:
if __name__ == '__main__':
    print("Creating Model and Initializing weights")

    # Single configuration run. To sweep, loop over
    # (CategoricalNet, CrossEntropyLoss, "_categorical") and
    # (RegressionNet, MSELoss, "_regression") with schedule in (True, False).
    model_class, loss_fn, suffix = CategoricalNet, torch.nn.CrossEntropyLoss(), "_categorical"
    schedule = True

    model = model_class()
    model.apply(weight_init)
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0], [1])])

    # Training stays under the guard: it uses names defined only inside it, so
    # running it at module level would fail whenever this file is imported.
    start = time.time()
    train(model, transform, num_epochs=150, batch_size=64, lr_schedule=schedule, loss=loss_fn, suffix=suffix)
    end = time.time()
    print('TRAIN TIME:')
    print('%.2gs'%(end-start))