In [1]:
from __future__ import annotations
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import argparse
import pandas as pd

from pathlib import Path
import os, sys

from fastai.data.all import Transform, DataBlock, RandomSplitter, L
from fastai.vision.all import *
from fastai.distributed import *
from fastai.callback.all import SaveModelCallback, EarlyStoppingCallback
import segmentation_models_pytorch as smp

# Load StarCraft2Sensor stuff
# NOTE: "__file__" below is a plain string literal, not the `__file__` variable
# (presumably because that variable is unavailable inside a notebook kernel — TODO confirm).
# realpath() resolves the literal against the current working directory, so dirname()
# yields the directory the kernel was started in.
ipynb_dir = os.path.dirname(os.path.realpath("__file__"))
# The parent of that directory is treated as the project root.
code_root = os.path.join(ipynb_dir, '..')
sys.path.append(code_root)  # Needed for import below

from sc2sensor.dataset import StarCraftSensor
from sc2sensor.utils.sensor_utils import SensorPlacementDataset, SUPPORTED_PLACEMENT_KINDS
from sc2sensor.utils.unit_type_data import NONNEUTRAL_CHANNEL_TO_ID, NONNEUTRAL_ID_TO_NAME
# Unit name for every non-neutral channel, listed in channel-index order
CHANNEL_TO_NAME = [
    NONNEUTRAL_ID_TO_NAME[NONNEUTRAL_CHANNEL_TO_ID[channel]]
    for channel in range(len(NONNEUTRAL_CHANNEL_TO_ID))
]
# Drop the '*_h' / '*_v' barrier placement kinds since the diag kinds are more fitting
PLACEMENT_KINDS = [kind for kind in SUPPORTED_PLACEMENT_KINDS if not kind.endswith(('_h', '_v'))]


data_subdir = 'starcraft-sensor-dataset'
data_root = os.path.join(code_root, 'data') # Data root directory

In [2]:
class FakeArgs():
    """Bare attribute container used in place of parsed command-line arguments."""

    def __str__(self):
        # Render whatever attributes have been set on this instance
        return str(vars(self))

# Notebook stand-in for command-line arguments
args = FakeArgs()
vars(args).update(max_samples=60000, batch_size=24, n_epochs=10)

print(f'Starting with inputs: {args}')

# Only report the cap when one is actually set
if args.max_samples is not None:
    print(f'Using a max number of a {args.max_samples}')

Starting with inputs: {'max_samples': 60000, 'batch_size': 24, 'n_epochs': 10}
Using a max number of a 60000


In [3]:
def save_benchmark_results(save_root, result, name, benchmark_kind,
                             sensor_kwargs=None, metric_name='MSE', raw_model_path=None):
    """Append one benchmark result row to the CSV for `benchmark_kind`.

    The target file is `<benchmark_kind>_results.csv`, prefixed with `sensor_`
    when `sensor_kwargs` is given. The file (and `save_root` itself) is created
    on first use; later calls append to the rows already stored.

    Args:
        save_root: Directory (str or Path) holding the results CSV files.
        result: Metric value to record.
        name: Model name stored in the `model_name` column.
        benchmark_kind: Either 'unit_identification' or 'next_window'.
        sensor_kwargs: Optional sensor configuration dict; its 'kind' entry is
            stored in `sensor_kind` and the whole dict in `sensor_kwargs`.
        metric_name: Column name under which `result` is stored.
        raw_model_path: Optional path of the evaluated model (Path objects are
            stored as strings).

    Returns:
        True once the CSV has been written.

    Raises:
        NotImplementedError: If `benchmark_kind` is not a supported kind.
    """
    if benchmark_kind == 'unit_identification':
        results_file = 'unit_identification_results.csv' if sensor_kwargs is None else 'sensor_unit_identification_results.csv'
    elif benchmark_kind == 'next_window':
        results_file = 'next_window_results.csv' if sensor_kwargs is None else 'sensor_next_window_results.csv'
    else:
        raise NotImplementedError(f'No benchmark file exists for type {benchmark_kind}')

    save_root = Path(save_root)
    # Make sure the output directory exists so the first save does not fail
    save_root.mkdir(parents=True, exist_ok=True)
    results_path = save_root/results_file

    new_row = pd.DataFrame({
        'model_name': name,
        metric_name: [result],
        'sensor_kind': None if sensor_kwargs is None else sensor_kwargs['kind'],
        'sensor_kwargs': [sensor_kwargs],
        'raw_model_path': str(raw_model_path) if isinstance(raw_model_path, Path) else raw_model_path})

    if results_path.exists():
        # Append to the previously saved rows
        results_df = pd.concat((pd.read_csv(results_path), new_row), ignore_index=True)
    else:
        # First result for this benchmark: no need to concatenate with an empty frame
        results_df = new_row

    results_df.to_csv(results_path, index=False)
    return True

def get_best_model(model_root, with_suffix=False):
    """Locate the checkpoint saved at the epoch with the lowest validation loss.

    Reads `train_history.csv` inside `model_root`, takes the row position of the
    smallest `valid_loss` as the best epoch, and looks for exactly one entry in
    `model_root` whose path ends with `_<best_epoch>.pth`.

    Args:
        model_root: Directory (str or Path) with the history CSV and checkpoints.
        with_suffix: When True the returned path keeps its `.pth` suffix,
            otherwise the suffix is stripped.

    Returns:
        Path of the best checkpoint.
    """
    model_root = Path(model_root)
    history = pd.read_csv(model_root/'train_history.csv')
    best_epoch = history.valid_loss.argmin()

    wanted_ending = f'_{best_epoch}.pth'
    best_model = []
    for candidate in model_root.glob('*'):
        if str(candidate).endswith(wanted_ending):
            best_model.append(candidate)
    assert len(best_model) == 1, f'Found {len(best_model)} best model(s) instead of just 1.'

    checkpoint = best_model[0]
    return checkpoint if with_suffix else checkpoint.with_suffix('')

# Next Window Experiments

In [4]:
# Create subclass of original dataset for next window prediction
class NextWindowDataset(StarCraftSensor):
  """Dataset of (window, change-to-next-window) pairs taken from the same replay.

  Each item is `(x, y)`: `x` is the dense float hyperspectral tensor of one window
  (player 1 and player 2 concatenated along dim -3) and `y` is the element-wise
  difference between the following window and `x`.

  The parent dataset is always built with `use_sparse=True` and
  `compute_labels=False`; passing either keyword is rejected.

  Args:
    *args, **kwargs: Forwarded to `StarCraftSensor`.
    max_samples: Optional cap on the reported length. Indexing honours the cap.
  """
  def __init__(self, *args, max_samples=None, **kwargs):
    assert 'use_sparse' not in kwargs, 'Cannot set use_sparse with this dataset.'
    assert 'compute_labels' not in kwargs, 'Cannot set compute_labels with this dataset.'
    super().__init__(*args, use_sparse=True, compute_labels=False, **kwargs)
    self.max_samples = max_samples

    # Sort data so that next index is merely + 1
    self.metadata = self.metadata.sort_values(['static.replay_name', 'dynamic.window_idx']).reset_index(drop=True)
    md = self.metadata

    # Get starting window indices: every window that is not the last one of its replay
    start_windows = md[(md['dynamic.num_windows'] > 1)
                       & (md['dynamic.window_idx'] < (md['dynamic.num_windows'] - 1))]
    self.start_idx = start_windows.index

  def __getitem__(self, idx):
    # Resolve negative indices against the (possibly capped) length and reject
    # out-of-range ones, so indexing and iteration agree with __len__
    n_items = len(self)
    if idx < 0:
      idx = idx + n_items
    if not 0 <= idx < n_items:
      raise IndexError(f'Index {idx} is out of range for a dataset of length {n_items}')

    # Get original indices of start and end based on input
    md = self.metadata
    # Assumes sorted
    orig_idx = self.start_idx[idx]
    next_idx = orig_idx + 1
    # Only sanity check first and last as this may be expensive
    if idx == 0 or idx == n_items - 1:
      assert md['static.replay_name'][orig_idx] == md['static.replay_name'][next_idx], 'Replays are not the same'
      assert md['dynamic.window_idx'][orig_idx] + 1 == md['dynamic.window_idx'][next_idx], 'Window indices are not adjacent'

    # Get combined hyperspectral images
    def get_hyperspectral_dense(metadata_idx):
      # Concatenate player1 and player2 hyperspectral
      replay_file, window_idx = self._get_replay_and_window_idx(metadata_idx)
      with np.load(replay_file) as data:
        player_1_hyperspectral = self._extract_hyperspectral(
          'player_1', data, window_idx)
        player_2_hyperspectral = self._extract_hyperspectral(
          'player_2', data, window_idx)
      return torch.concat([player_1_hyperspectral.to_dense(),
                           player_2_hyperspectral.to_dense()],
                          dim=-3).float()

    x = get_hyperspectral_dense(orig_idx)
    y = get_hyperspectral_dense(next_idx) - x # Compute diff
    return x, y

  def __len__(self):
    # Number of usable start windows, limited by max_samples when it is set
    if self.max_samples is not None:
      return min(self.max_samples, len(self.start_idx))
    else:
      return len(self.start_idx)

# Create fastai dataloaders given the PyTorch dataset
class AddChannelCodes(Transform):
  "Attach the channel `codes` metadata to decoded items"
  def __init__(self, codes=None):
    self.codes = codes
    if codes is None:
      return
    # Expose the codes as the vocabulary together with their count
    self.vocab = codes
    self.c = len(codes)

  def decodes(self, o):
    # Without codes the item passes through untouched
    if self.codes is None:
      return o
    o.codes = self.codes
    return o
    
# HACK: Put all instances in both "train" and "valid"
# From https://forums.fast.ai/t/solved-not-splitting-datablock/84759/3
def all_splitter(o):
    "Return every index of `o` twice: once for the train split and once for the valid split."
    n_items = len(o)
    train_idxs = L(int(i) for i in range(n_items))
    valid_idxs = L(int(i) for i in range(n_items))
    return train_idxs, valid_idxs
    
# One code per (player, unit name) pair; all P1 codes come before all P2 codes
SC2_CODES = [f'{player}_{name}' for player in ('P1', 'P2') for name in CHANNEL_TO_NAME]
def create_dataloaders_from_dataset(dataset, splitter=None, **kwargs):
    """Wrap an indexable `(x, y)` dataset in fastai dataloaders.

    Items are the integer indices of `dataset`; inputs and targets are fetched
    through closures over `dataset`. `SC2_CODES` is attached to items via
    `AddChannelCodes`.

    Args:
        dataset: Indexable dataset whose items are `(x, y)` pairs.
        splitter: fastai splitter; defaults to `RandomSplitter(seed=0)`.
        **kwargs: Forwarded to `DataBlock.dataloaders`. Must not contain
            `get_x`, `get_y` or `get_items`.

    Returns:
        The dataloaders built by the `DataBlock`.
    """
    # These getters are defined below and must not be overridden by the caller
    for reserved in ('get_x', 'get_y', 'get_items'):
        assert reserved not in kwargs
    if splitter is None:
        splitter = RandomSplitter(seed=0)
    assert len(SC2_CODES) == dataset[0][0].shape[0], 'Number of codes does not match number of channels'

    # Needs to have reference to dataset for closures
    def _all_indices(d):
        return list(range(len(d)))

    def _input_at(idx):
        return dataset[idx][0]

    def _target_at(idx):
        return dataset[idx][1]

    block = DataBlock(
        blocks=None, # These are just transforms
        get_items=_all_indices,
        get_x=_input_at,
        get_y=_target_at,
        splitter=splitter,
        item_tfms=[AddChannelCodes(SC2_CODES)],
    )
    return block.dataloaders(dataset, **kwargs)

In [5]:
# Load test set
next_window_test = NextWindowDataset(
    root=data_root, subdir=data_subdir, train=False, max_samples=args.max_samples)

print(f'Num Test: {len(next_window_test)}')
# Report the tensor shapes of the last (input, target) pair
last_x, last_y = next_window_test[-1]
print(f'x.shape: {last_x.shape}, y.shape: {last_y.shape}')

Using cached CSV metadata
Not computing labels
Post-processing metadata
Finished dataset init
Num Test: 60000
x.shape: torch.Size([340, 64, 64]), y.shape: torch.Size([340, 64, 64])


In [7]:
# Sensor placement configuration for the quick visual check below
sensor_kwargs = dict(n_sensors=50, radius=5.5, kind='grid', failure_rate=0.2)
sensor_next_window_test = SensorPlacementDataset(
    next_window_test, **sensor_kwargs,
    return_mask=False, make_cache=False, noiseless_ground_truth=True)

# Show the first sample's input summed over axis 0
plt.imshow(sensor_next_window_test[0][0].sum(0))

# Unit Type Identification Experiments

In [None]:
from sc2sensor.utils.unit_type_data import NONNEUTRAL_CHANNEL_TO_ID, NONNEUTRAL_ID_TO_NAME
# Unit name for every non-neutral channel, listed in channel-index order
CHANNEL_TO_NAME = [
    NONNEUTRAL_ID_TO_NAME[NONNEUTRAL_CHANNEL_TO_ID[channel]]
    for channel in range(len(NONNEUTRAL_CHANNEL_TO_ID))
]

In [None]:
# Loading dataset
class SegmentationDataset(torch.utils.data.Dataset):
    """Image / label-mask pairs read from PNG files below `segment_path`.

    Images are every `*.png` found (recursively) under `<segment_path>/images`,
    kept in sorted order so that an index always maps to the same file. The
    label file of an image is derived from the image path by replacing
    'images' with 'labels' and appending '_labels.png' to the extension-less
    name.

    Args:
        segment_path: Root directory containing the `images` folder.
        create_metadata: When True, build `metadata` (one row per image) and
            `match_metadata` (one row per distinct replay), both with a
            'static.replay_name' column.
    """

    def __init__(self, segment_path, create_metadata=True):
        super().__init__()
        self.path = Path(segment_path)
        self.X_filenames = self._get_files()
        if create_metadata:
            self.metadata, self.match_metadata = self._make_metadata()

    def __len__(self):
        return len(self.X_filenames)

    def __getitem__(self, idx):
        """Return the `(image, label)` pair at `idx` as uint8 arrays."""
        X_filename = self.X_filenames[idx]
        # NOTE(review): replace() rewrites every 'images' occurrence in the path,
        # including parent directories — confirm the dataset location is safe.
        y_filename = os.path.splitext(X_filename)[0].replace('images','labels') + '_labels.png'
        return self._read_png_uint8(X_filename), self._read_png_uint8(y_filename)

    @staticmethod
    def _read_png_uint8(filename):
        """Read a PNG and convert its values to uint8."""
        # Round before casting: the values are scaled by 255 as in the original
        # conversion, and a plain astype() would truncate any value that ends
        # up a hair below its integer, shifting it down by one.
        return np.rint(plt.imread(str(filename)) * 255).astype(np.uint8)

    @staticmethod
    def _replay_name(filename):
        """Extract the replay name: the second-to-last '_'-separated part of the file name."""
        # Work on the file name only so directory names and the platform's
        # path separator cannot leak into the result
        file_path = Path(filename)
        parts = file_path.name.split('_')
        return parts[-2] if len(parts) >= 2 else file_path.stem

    def _get_files(self):
        # Sorted so that indices (and any seeded split over them) are reproducible
        files = sorted((self.path / 'images').glob('**/*.png'))
        assert len(files) > 0, f'No .png files found in {self.path}'
        return files

    def _make_metadata(self):
        replay_names = [self._replay_name(f) for f in self.X_filenames]
        metadata = pd.DataFrame({'static.replay_name':replay_names})
        match_metadata = metadata.drop_duplicates(subset=['static.replay_name']).reset_index(drop=True)
        return metadata, match_metadata

In [None]:
# Test split of the segmentation data inside the dataset directory
segmentation_data_path = Path(data_root, 'starcraft-sensor-dataset', 'segment', 'test')

In [None]:
def test_unit_identification_model(name, batch_size, segment_path, path_to_model):
    """Load a saved U-Net and score it on the segmentation images.

    Args:
        name: Model name; must be a key of the architecture table below.
        batch_size: Batch size used for the dataloaders.
        segment_path: Path object whose `images` folder holds the input PNGs.
        path_to_model: Checkpoint location handed to `learner.load`.

    Returns:
        The first value returned by `learner.validate()` (presumably the
        validation loss — TODO confirm).

    NOTE(review): `validate()` runs on the split produced by
    `RandomSplitter(seed=0)`, not on every file under `segment_path` — confirm
    this is intended for a test-set benchmark.
    """
    # Encoder constructors keyed by model name
    FASTAI_ARCHES = {
        'unet_resnet18': resnet18,
        'unet_resnet34': resnet34,
        'unet_xresnet18': xresnet18_deep,
        'unet_xresnet34': xresnet34_deep,
        'unet_squeezenet1_0': squeezenet1_0,
        'unet_squeezenet1_1': squeezenet1_1,
        'unet_densenet121': densenet121,
        'unet_densenet169': densenet169,
    }

    def _label_file_for(filename):
        # Label PNGs mirror the image tree: 'images' -> 'labels', plus a '_labels' suffix
        stem = os.path.splitext(filename)[0]
        return stem.replace('images','labels') + '_labels.png'

    segment_block = DataBlock(
        blocks=(ImageBlock, MaskBlock(codes=CHANNEL_TO_NAME)),
        get_items=get_image_files,
        get_y=_label_file_for,
        splitter=RandomSplitter(seed=0),
        batch_tfms=None)
    dls = segment_block.dataloaders(segment_path/'images', shuffle=True, bs=batch_size, num_workers=12)

    # Create learner and restore the trained weights
    learner = unet_learner(arch=FASTAI_ARCHES[name], dls=dls).load(path_to_model)

    metrics = learner.validate()
    # Now that testing has finished, empty the cache
    torch.cuda.empty_cache()
    return metrics[0]

In [None]:
print(f'Starting result aggregation on segmentation dataset: {str(segmentation_data_path)}. ')

# NOTE(review): `extracted_model_root` and `benchmark_results_save_root` are not defined in any
# visible cell; they must be set (as Path objects) before this cell can run on a fresh kernel.
unit_identification_model_root = extracted_model_root / 'unit_identification'

# Keep only the per-model directories, which are named 'unet_<arch>'. The check is made on the
# directory name: the full path string begins with the root directory, so a prefix test on it
# can never match. Sorted so results are appended in a stable order.
unit_identification_model_dirs = sorted(f for f in unit_identification_model_root.glob('**/*')
                                        if f.is_dir() and f.name.startswith('unet_'))

for model_dir in unit_identification_model_dirs:
    name = model_dir.name

    print('\n'*5, f'Starting {name}.'.center(60, '-'))

    # Checkpoint with the lowest validation loss recorded during training
    path_to_model = get_best_model(model_dir)

    print(f'Using model {path_to_model}')

    result = test_unit_identification_model(name, args.batch_size, segmentation_data_path, path_to_model)

    save_benchmark_results(benchmark_results_save_root, result, name, benchmark_kind='unit_identification',
                             sensor_kwargs=None, metric_name='MSE', raw_model_path=path_to_model)

    print(f'Finished {name}.')

In [None]:
# Display the saved unit-identification results ordered by ascending MSE
pd.read_csv(benchmark_results_save_root/'unit_identification_results.csv').sort_values(by='MSE')

# Getting *sensor* unit type identification results

In [None]:
# Segmentation test data, with replay metadata, shared by all sensor evaluations below
segmentation_dataset = SegmentationDataset(segmentation_data_path, create_metadata=True)

In [None]:
def test_sensor_unit_identification_model(name, batch_size, segmentation_dataset,
                                          path_to_model, sensor_kwargs):
    """Load a saved U-Net and score it on sensor-observed segmentation data.

    `segmentation_dataset` is wrapped in a `SensorPlacementDataset` configured
    by `sensor_kwargs`; the wrapped items provide the inputs and targets.

    Args:
        name: Model name; must be a key of the architecture table below.
        batch_size: Batch size used for the dataloaders.
        segmentation_dataset: Dataset to wrap with sensor placement.
        path_to_model: Checkpoint location handed to `learner.load`.
        sensor_kwargs: Keyword arguments for `SensorPlacementDataset`.

    Returns:
        The first value returned by `learner.validate()` (presumably the
        validation loss — TODO confirm).

    NOTE(review): `validate()` runs on the split produced by
    `RandomSplitter(seed=0)`, not on the whole dataset — confirm this is
    intended for a test-set benchmark.
    """
    # Encoder constructors keyed by model name
    FASTAI_ARCHES = dict(
      unet_resnet18=resnet18,
      unet_resnet34=resnet34,
      unet_xresnet18=xresnet18_deep,
      unet_xresnet34=xresnet34_deep,
      unet_squeezenet1_0=squeezenet1_0,
      unet_squeezenet1_1=squeezenet1_1,
      unet_densenet121=densenet121,
      unet_densenet169=densenet169,
    )

    segmentation_sensor_placement_dataset = SensorPlacementDataset(
            segmentation_dataset, **sensor_kwargs,
            return_mask=False, make_cache=False, noiseless_ground_truth=True)

    # Items are dataset indices; x / y are looked up lazily through the closures
    block = DataBlock(
        blocks=(ImageBlock, MaskBlock),
        get_items=lambda d: list(range(len(d))),
        get_x=lambda idx: segmentation_sensor_placement_dataset[idx][0],
        get_y=lambda idx: segmentation_sensor_placement_dataset[idx][1],
        splitter=RandomSplitter(seed=0),
    )

    # `bs` is the batch-size argument of fastai's dataloaders(), matching
    # test_unit_identification_model
    dls = block.dataloaders(segmentation_sensor_placement_dataset, bs=batch_size)

    # Create learner
    # NOTE(review): n_out=170 presumably equals the number of unit-type channels — confirm
    learner = unet_learner(arch=FASTAI_ARCHES[name], dls=dls, n_out=170)
    learner = learner.load(path_to_model)

    result = learner.validate()
    # Now that testing has finished, empty the cache
    torch.cuda.empty_cache()
    return result[0]

In [None]:
print(f'Starting **sensor** result aggregation on segmentation dataset: {str(segmentation_data_path)}. ')

# NOTE(review): `extracted_model_root` and `benchmark_results_save_root` are not defined in any
# visible cell; they must be set (as Path objects) before this cell can run on a fresh kernel.
sensor_unit_identification_model_root = extracted_model_root / 'sensor_unit_identification'

# Model directories live directly below a '<kind>_sensors' directory. Requiring such a parent
# keeps deeper, unrelated sub-directories out of the loop, since the sensor kind is parsed
# from the parent's name. Sorted so results are appended in a stable order.
sensor_unit_identification_model_dirs = sorted(
    f for f in sensor_unit_identification_model_root.glob('**/*')
    if f.is_dir() and not str(f).endswith('_sensors') and '_sensors' in f.parent.name)

for model_dir in sensor_unit_identification_model_dirs:
    # '<kind>_sensors' -> '<kind>'
    sensor_kind = model_dir.parent.name.split('_sensors')[0]
    name = model_dir.name

    sensor_kwargs = {'n_sensors': 50, 'radius': 5.5, 'kind': sensor_kind, 'failure_rate': 0.2}

    print('\n'*5, f'Starting {name} with {sensor_kind} placement.'.center(60, '-'))

    # Checkpoint with the lowest validation loss recorded during training
    path_to_model = get_best_model(model_dir)

    print(f'Using model {path_to_model}')

    result = test_sensor_unit_identification_model(name, args.batch_size,
                                                   segmentation_dataset, path_to_model, sensor_kwargs)

    save_benchmark_results(benchmark_results_save_root, result, name, benchmark_kind='unit_identification',
                             sensor_kwargs=sensor_kwargs, metric_name='MSE', raw_model_path=path_to_model)

    print(f'Finished {name} with {sensor_kind} placement.')

In [None]:
# Display the saved sensor unit-identification results.
# NOTE(review): `benchmark_results_save_root` is not defined in any visible cell — confirm it is set earlier.
pd.read_csv(benchmark_results_save_root/'sensor_unit_identification_results.csv')