In [72]:
#@title connect google drive
# Mount Google Drive inside the Colab VM so that DATASETS_DIR_PATH (which lives
# under /content/drive) is readable. If Drive is already mounted, drive.mount
# only prints a notice instead of remounting.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [73]:
# Root folder (on the mounted Drive) containing one sub-folder per dataset and a
# "tempJson" folder of exported game maps. Editable through the Colab form field.
DATASETS_DIR_PATH = '/content/drive/MyDrive/Colab Notebooks/project/wireworldYolo/datasets' #@param {type:"string"}

In [74]:
%%capture
!pip install pyyaml==5.4.1

In [75]:
#@title import package
import os
import math
import time
import random
import json
import yaml
import numpy as np
import cv2
from google.colab.patches import cv2_imshow

In [76]:
# Pixel colour for each cell-state value found in a game map's 'map' dict.
# Channels are in BGR order because the rendered images are written with cv2.
cellColor = dict(enumerate([
  [7, 193, 255],   # 0: yellow
  [243, 150, 33],  # 1: blue
  [54, 67, 244],   # 2: red
]))

In [84]:
def computeEdge(map):
  """Return the inclusive bounding box of all cells in a game map.

  Args:
    map: iterable of cell keys formatted as "x,y" (e.g. the keys of the
      JSON 'map' dict); only the first two comma-separated fields are used.

  Returns:
    [minX, minY, maxX, maxY], all inclusive. An empty map yields [0, 0, 0, 0].
  """
  xs = []
  ys = []
  for cell in map:
    pos = [int(n) for n in cell.split(',')]
    xs.append(pos[0])
    ys.append(pos[1])
  if not xs:
    return [0, 0, 0, 0]
  # Plain min/max: no "unset" sentinel is needed, so a legitimate coordinate
  # of 0 can never be mistaken for "not yet initialised".
  return [min(xs), min(ys), max(xs), max(ys)]
def getRelativeNameDataPairsInRect(nameData, rect, verbose=False):
  """Collect the labelled boxes that overlap `rect`, clipped and made rect-relative.

  Args:
    nameData: dict mapping "x1,y1,x2,y2" keys to a label name.
    rect: [left, top, right, bottom] in the same coordinate system as the keys.
    verbose: when True, print each kept box next to its original key
      (debug aid; off by default so slicing large maps does not flood output).

  Returns:
    List of [[x1, y1, x2, y2], name] pairs, in nameData iteration order, where
    the box is clipped to `rect` and translated so rect's top-left is (0, 0).
    Boxes that only touch an edge of `rect` are treated as non-overlapping.
  """
  left, top, right, bottom = rect[0], rect[1], rect[2], rect[3]
  relativeNameDataPairs = []
  for nameRange, name in nameData.items():
    namePos = [int(n) for n in nameRange.split(',')]
    x1, y1, x2, y2 = namePos[0], namePos[1], namePos[2], namePos[3]
    if x1 >= right or x2 <= left or y1 >= bottom or y2 <= top:
      continue
    # Clip to the rectangle, then shift into its local coordinates.
    namePos[0] = max(x1, left) - left
    namePos[1] = max(y1, top) - top
    namePos[2] = min(x2, right) - left
    namePos[3] = min(y2, bottom) - top
    if verbose:
      print(namePos, '<', nameRange)
    relativeNameDataPairs.append([namePos, name])
  return relativeNameDataPairs
def generateId():
  """Return an identifier used as the file stem of a generated sample.

  Format: '<time.time()>-<hex suffix>'. The suffix carries 64 random bits so
  IDs stay distinct even when several are generated within one clock tick;
  a collision would make two samples overwrite each other's files.
  """
  return f'{time.time()}-{hex(random.getrandbits(64))}'
def labelRectRot90(xc, yc, w, h, times):
  """Rotate a normalised YOLO box to follow np.rot90(image, k=times, axes=(0, 1)).

  (xc, yc, w, h) are fractions of a square image's side, so the image is
  treated as the unit square. `times` is taken modulo 4. Quarter turns swap
  width and height; half turns keep them.

  Returns:
    [xc, yc, w, h] of the box in the rotated image.
  """
  quarterTurns = times % 4
  if quarterTurns == 1:
    return [yc, 1 - xc, h, w]
  if quarterTurns == 2:
    return [1 - xc, 1 - yc, w, h]
  if quarterTurns == 3:
    return [1 - yc, xc, h, w]
  return [xc, yc, w, h]

In [None]:
# Build YOLO train/val samples from every exported game map.
#
# Expected layout under DATASETS_DIR_PATH:
#   <datasetName>/<datasetName>.yml  dataset config: 'train' and 'val' (image
#                                    dirs relative to DATASETS_DIR_PATH),
#                                    'names', optional 'sliceWidth' (default 128)
#   tempJson/*.json                  exported maps with 'map' and 'nameData'
#
# Each map is rendered one pixel per cell and cut into sliceWidth x sliceWidth
# tiles. Every tile holding at least one labelled object is written once to
# the 'val' images and in four rotations to the 'train' images, each with a
# YOLO label file in the sibling 'labels' directory.


def _labelsDirFor(imagesDir):
  """Return the 'labels' directory that sits next to `imagesDir`."""
  return os.path.join(os.path.abspath(os.path.join(imagesDir, os.pardir)), 'labels')


def _formatLabelRows(rows):
  """Serialise [classId, xc, yc, w, h] rows as YOLO label text, one row per line."""
  return '\n'.join(' '.join(str(value) for value in row) for row in rows)


def _writeSample(imagesDir, sampleId, image, labelContent):
  """Write `<sampleId>.jpg` into imagesDir and `<sampleId>.txt` into its labels dir.

  Missing directories are created. Raises OSError if cv2.imwrite returns
  False, since it can otherwise fail silently (e.g. on a missing directory).
  """
  labelsDir = _labelsDirFor(imagesDir)
  os.makedirs(imagesDir, exist_ok=True)
  os.makedirs(labelsDir, exist_ok=True)
  imagePath = os.path.join(imagesDir, f'{sampleId}.jpg')
  # np.rot90 returns a strided view; hand OpenCV a plain contiguous buffer.
  if not cv2.imwrite(imagePath, np.ascontiguousarray(image)):
    raise OSError(f'cv2.imwrite failed for "{imagePath}"')
  with open(os.path.join(labelsDir, f'{sampleId}.txt'), 'w', encoding='utf-8') as labelFile:
    labelFile.write(labelContent)


# --- Load every dataset config -------------------------------------------------
datasetConfigs = []
datasetDirs = [
  p for p in os.listdir(DATASETS_DIR_PATH)
  if p != 'tempJson' and os.path.isdir(os.path.join(DATASETS_DIR_PATH, p))
]
for datasetDirName in datasetDirs:
  datasetDirPath = os.path.join(DATASETS_DIR_PATH, datasetDirName)
  ymlPath = os.path.join(datasetDirPath, f'{datasetDirName}.yml')
  print(ymlPath)
  if not os.path.isfile(ymlPath):
    raise ValueError(f'dataset config "{ymlPath}" not found! (Please confirm whether the config file is under the "dataset/{datasetDirName}" folder.)')
  with open(ymlPath, 'r', encoding='utf-8') as f:
    # safe_load: a dataset config only needs plain YAML types, so there is no
    # reason to allow arbitrary Python object construction.
    datasetConfig = yaml.safe_load(f)
  # 'names' may be written as {index: name} or as a list. Normalise to a list
  # exactly once here so the per-map loop below never has to mutate the config.
  if isinstance(datasetConfig.get('names'), dict):
    datasetConfig['names'] = list(datasetConfig['names'].values())
  datasetConfigs.append(datasetConfig)

# --- Locate exported game maps -------------------------------------------------
tempJsonPath = os.path.join(DATASETS_DIR_PATH, 'tempJson')
if not os.path.isdir(tempJsonPath):
  raise ValueError('Directory "tempJson" not found! (Please confirm whether the "tempJson" folder is under the "dataset" folder.)')
gameMapFiles = [p for p in os.listdir(tempJsonPath) if p.endswith('.json')]

# --- Render, slice and export ----------------------------------------------------
for gameMapFileName in gameMapFiles:
  gameMapPath = os.path.join(tempJsonPath, gameMapFileName)
  with open(gameMapPath, 'r', encoding='utf-8') as gameMapFile:
    gameMapContent = json.load(gameMapFile)

  # Render one pixel per cell, with the map's top-left cell at pixel (0, 0).
  # uint8: every colour channel fits in 0..255 and JPEG stores 8-bit data, so
  # this avoids an 8x larger int64 buffer.
  edge = computeEdge(gameMapContent['map'])
  wholeImage = np.zeros([edge[3]-edge[1]+1, edge[2]-edge[0]+1, 3], dtype=np.uint8)
  for cell, state in gameMapContent['map'].items():
    pos = [int(n) for n in cell.split(',')]
    wholeImage[pos[1] - edge[1], pos[0] - edge[0]] = cellColor[state]

  for config in datasetConfigs:
    trainPath = os.path.join(DATASETS_DIR_PATH, config['train'])
    validPath = os.path.join(DATASETS_DIR_PATH, config['val'])
    names = config['names']
    sliceWidth = config.get('sliceWidth', 128)
    for r in range(math.ceil(wholeImage.shape[0]/sliceWidth)):
      for c in range(math.ceil(wholeImage.shape[1]/sliceWidth)):
        sliceRect = [c*sliceWidth, r*sliceWidth, (c+1)*sliceWidth, (r+1)*sliceWidth]

        # NOTE(review): sliceRect is in image coordinates (shifted by edge[0],
        # edge[1]) but nameData keys are passed through unshifted. That is only
        # consistent if nameData is already relative to the map's top-left cell
        # (or the map starts at 0,0) -- confirm against the exporter.
        relativeNameDataPairs = getRelativeNameDataPairsInRect(gameMapContent['nameData'], sliceRect)
        # YOLO rows: class index, then box centre and size as fractions of the tile side.
        labelContentRows = [
          [
            names.index(name),
            ((box[0] + box[2]) / 2) / sliceWidth,
            ((box[1] + box[3]) / 2) / sliceWidth,
            (box[2] - box[0]) / sliceWidth,
            (box[3] - box[1]) / sliceWidth,
          ]
          for box, name in relativeNameDataPairs
          if name in names
        ]
        if not labelContentRows:
          continue  # tiles without any object of this dataset's classes are not exported

        sliceId = generateId()
        imagePart = wholeImage[sliceRect[1]:sliceRect[3], sliceRect[0]:sliceRect[2]]
        # Tiles on the right/bottom border are zero-padded to a full square.
        slicedImage = np.zeros([sliceWidth, sliceWidth, 3], dtype=np.uint8)
        slicedImage[:imagePart.shape[0], :imagePart.shape[1], :] = imagePart

        # NOTE(review): this validation image is identical to the unrotated
        # training image written below, so validation metrics will largely
        # measure training-set fit -- consider holding out whole maps instead.
        _writeSample(validPath, sliceId, slicedImage, _formatLabelRows(labelContentRows))

        # Training set: the tile in all four 90-degree rotations with matching
        # boxes. Each rotation gets its own file stem so later rotations do not
        # overwrite earlier ones.
        for quarterTurns in range(4):
          rotatedRows = [
            [classId, *labelRectRot90(xc, yc, w, h, quarterTurns)]
            for classId, xc, yc, w, h in labelContentRows
          ]
          _writeSample(
            trainPath,
            f'{sliceId}-rot{quarterTurns * 90}',
            np.rot90(slicedImage, k=quarterTurns, axes=(0, 1)),
            _formatLabelRows(rotatedRows),
          )