# Coding Exercise

You must correctly implement the function described in the prompt below.

Feel free to test out pieces of code to help you write the solution.

Please thoroughly test that the final code implements the function correctly.

## Prompt

**Function signature:** `countAliveCells(f: List[str], K: int) -> int`

    Magical Girl Sayaka just learned about Conway's Game of Life. She is now thinking about new rules for this game.
    In the Game of Life, an infinite plane is divided into a grid of unit square cells. 
    At any moment, each cell is either alive or dead.
    Every second the state of each cell changes according to a fixed rule.
    In Sayaka's version of the game the following rule is used: 

     Consider any cell C. Look at the current states of the cell C and all four cells that share a side with C. 
     If at least one of these five cells is alive, cell C will be alive in the next second. Otherwise, cell C will be dead in the next second.
     Note that each second the rule is applied on all cells at the same time.  
    Sayaka wants to know how many cells are alive after K seconds.
    You are given the int K and a String[] field that describes the initial state of the plane. field describes only some rectangular area of the plane. More precisely, character j of element i of field is 'o' if the cell in the i-th row and j-th column of the rectangular area is alive, and it is '.' otherwise. Cells which aren't described in field are initially dead.
    Return the number of alive cells after K seconds.
    Constraints
    -field will contain between 1 and 50 elements, inclusive.
    -Each element of field will contain between 1 and 50 characters, inclusive.
    -All elements of field will contain the same number of characters.
    -Each character in each element of field will be either 'o' or '.'.
    -K will be between 1 and 1500, inclusive.
 
    Examples
    0)
        {"oo"
    ,"o."}
    3

    Returns: 36
    The status after 3 seconds is below.
    ...oo...
    ..oooo..
    .oooooo.
    oooooooo
    ooooooo.
    .ooooo..
    ..ooo...
    ...o....

    1)
        {".."
    ,".."}
    23

    Returns: 0
    All cells of the plane are dead.

    2)
        {"o"}
    1000

    Returns: 2002001

    3)
        {"o.oo.ooo"
    ,"o.o.o.oo"
    ,"ooo.oooo"
    ,"o.o..o.o"
    ,"o.o..o.o"
    ,"o..oooo."
    ,"..o.o.oo"
    ,"oo.ooo.o"}
    1234

    Returns: 3082590

    

The idea is to solve it by brute force.
In the worst case we will have 1500 iterations, and the cluster of alive cells will grow by up to 1500 cells. If I allocate a lattice of size 2000, we will need up to $2000\times2000\times1500 \sim 10^{10}$ iterations. Hopefully, I can optimize it later.

Actually, I will need a lattice of size 4000, because the cluster can grow in both directions along each axis. This increases the total number of iterations to $4000\times4000\times1500 \approx 2.4\times10^{10}$.
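The sizing argument can be checked with a few lines of arithmetic. This is a sketch that assumes the worst case allowed by the constraints (a 50 × 50 field and K = 1500) and the offset of 2000 used in get_lattice; the variable names are illustrative.

```python
# Worst case allowed by the constraints.
max_field = 50   # at most 50 rows and 50 columns
max_k = 1500     # at most 1500 seconds

# The cluster can grow by one cell per second on both sides of each axis.
max_extent = max_field + 2 * max_k
print(max_extent)  # 3050

# With the field placed at offset 2000 in a 4000 x 4000 lattice,
# the alive cells stay between these indices along each axis.
lattice_size, offset = 4000, 2000
lowest = offset - max_k                   # 500
highest = offset + max_field - 1 + max_k  # 3549
print(0 <= lowest and highest < lattice_size)  # True

# A brute-force sweep visits every lattice cell once per second.
operations = lattice_size * lattice_size * max_k
print(f"{operations:.1e}")  # 2.4e+10
```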

Let's start by implementing the lattice with the appropriate initial condition.

In [1]:
from numpy import ones, bool_

def get_lattice(f):
    size = [4000, 4000]
    lattice = ones(size, dtype=bool_)
    lattice = ~lattice
    
    off = 2000
    for i, row in enumerate(f):
        for j,char in enumerate(row):
            lattice[off+i,off+j] = True if char == 'o' else False
            
    return lattice

f = [
    '...oo...',
    '..oooo..',
    '.oooooo.',
    'oooooooo',
    'ooooooo.',
    '.ooooo..',
    '..ooo...',
    '...o....'
]
lattice = get_lattice(f)
print(lattice[1950:2050, 1950:2050])

[[False False False ... False False False]
 [False False False ... False False False]
 [False False False ... False False False]
 ...
 [False False False ... False False False]
 [False False False ... False False False]
 [False False False ... False False False]]


I will implement a function to pretty-print the lattice.

In [2]:
from numpy import ones, bool_

def get_lattice(f):
    size = [4000, 4000]
    lattice = ones(size, dtype=bool_)
    lattice = ~lattice
    
    off = 2000
    for i, row in enumerate(f):
        for j,char in enumerate(row):
            lattice[off+i,off+j] = True if char == 'o' else False
            
    return lattice

def print_lattice(lattice):
    
    for row in lattice:
        s = ''
        for i in range(row.size):
            s += 'o' if row[i] else '.'
        print(s)

f = [
    '...oo...',
    '..oooo..',
    '.oooooo.',
    'oooooooo',
    'ooooooo.',
    '.ooooo..',
    '..ooo...',
    '...o....'
]
lattice = get_lattice(f)
print_lattice(lattice[1980:2020, 1980:2020])

........................................
........................................
........................................
........................................
........................................
........................................
........................................
........................................
........................................
........................................
........................................
........................................
........................................
........................................
........................................
........................................
........................................
........................................
........................................
........................................
.......................oo...............
......................oooo..............
.....................oooooo.............
....................oooooooo............
................
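A more compact variant of print_lattice can be written with NumPy's np.where. In this sketch, lattice_to_str is a hypothetical name; it returns the string instead of printing it, which makes it easier to compare in tests.

```python
import numpy as np

def lattice_to_str(lattice: np.ndarray) -> str:
    """Render a boolean lattice using 'o' for alive and '.' for dead cells."""
    # np.where picks 'o' or '.' element-wise, giving an array of 1-char strings.
    chars = np.where(lattice, 'o', '.')
    return '\n'.join(''.join(row) for row in chars)

demo = np.array([[True, True], [True, False]])
print(lattice_to_str(demo))
# oo
# o.
```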

Now I will implement the solution by brute force.

I will optimize it by keeping track of x_min, x_max, y_min and y_max, which hold the minimum and maximum indices of the region that needs to be updated.

In [3]:
from numpy import ones, bool_
from typing import List

def get_lattice(f):
    size = [4000, 4000]
    lattice = ones(size, dtype=bool_)
    lattice = ~lattice
    
    off = 2000
    for i, row in enumerate(f):
        for j,char in enumerate(row):
            lattice[off+i,off+j] = True if char == 'o' else False
            
    # x indexes the rows of f and y indexes its columns
    x_min = off
    x_max = off+len(f)
    y_min = off
    y_max = off+len(f[0])
            
    return lattice, x_min, x_max, y_min, y_max

def print_lattice(lattice):
    
    for row in lattice:
        s = ''
        for i in range(row.size):
            s += 'o' if row[i] else '.'
        print(s)

def countAliveCells(f: List[str], K: int) -> int:
    lattice, x_min, x_max, y_min, y_max = get_lattice(f)
    
    n_alive = sum([row.count('o') for row in f])
    
    for k in range(K):
        
        for x in range(x_min, x_max+1):
            for y in range(y_min, y_max):
                if not lattice[x,y] and \
                  (lattice[x+1,y] or lattice[x,y+1] or lattice[x-1,y] or lattice[x-1,y]):
                    lattice[x,y] = True
                    n_alive += 1
        x_min -= 1
        x_max += 1
        y_min -= 1
        y_max += 1

    print_lattice(lattice[x_min:x_max,y_min:y_max])
    return n_alive
                        
def test():
    assert (a:=countAliveCells(["oo","o."], 3)) == 36, a
    
test()

........
...oo...
..ooo...
.oooo...
.oooo...
.oooo...
.oooo...
.oooo...


AssertionError: 25

The printout shows that the cluster is not growing to the right.
The problem is in the last term of

<code> lattice[x+1,y] or lattice[x,y+1] or lattice[x-1,y] or lattice[x-1,y] </code> 

it should be

<code> lattice[x+1,y] or lattice[x,y+1] or lattice[x-1,y] or lattice[x,y-1] </code> 
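Listing the four offsets once, in a table, makes a duplicated term like this one harder to introduce. A minimal sketch (NEIGHBOR_OFFSETS and has_alive_neighbor are hypothetical names) that works on a list of lists or a 2D NumPy array, assuming (x, y) is not on the border; the notebook guarantees this by keeping the cluster far from the edges of its large lattice.

```python
# The four cells that share a side with (x, y), written once as offsets.
NEIGHBOR_OFFSETS = ((1, 0), (-1, 0), (0, 1), (0, -1))

def has_alive_neighbor(lattice, x, y):
    """Return True if any cell sharing a side with (x, y) is alive.

    Assumes (x, y) is not on the border, so every neighbor index is valid.
    """
    return any(lattice[x + dx][y + dy] for dx, dy in NEIGHBOR_OFFSETS)

grid = [
    [False, False, False],
    [False, False, True],
    [False, False, False],
]
print(has_alive_neighbor(grid, 1, 1))  # True: (1, 2) is alive

grid[1][2] = False
grid[2][2] = True  # diagonal cells are not neighbors
print(has_alive_neighbor(grid, 1, 1))  # False
```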

In [4]:
from numpy import ones, bool_
from typing import List

def get_lattice(f):
    size = [4000, 4000]
    lattice = ones(size, dtype=bool_)
    lattice = ~lattice
    
    off = 2000
    for i, row in enumerate(f):
        for j,char in enumerate(row):
            lattice[off+i,off+j] = True if char == 'o' else False
            
    # x indexes the rows of f and y indexes its columns
    x_min = off
    x_max = off+len(f)
    y_min = off
    y_max = off+len(f[0])
            
    return lattice, x_min, x_max, y_min, y_max

def print_lattice(lattice):
    
    for row in lattice:
        s = ''
        for i in range(row.size):
            s += 'o' if row[i] else '.'
        print(s)

def countAliveCells(f: List[str], K: int) -> int:
    lattice, x_min, x_max, y_min, y_max = get_lattice(f)
    
    n_alive = sum([row.count('o') for row in f])
    
    for k in range(K):
        
        for x in range(x_min, x_max+1):
            for y in range(y_min, y_max):
                if not lattice[x,y] and \
                  (lattice[x+1,y] or lattice[x,y+1] or lattice[x-1,y] or lattice[x,y-1]):
                    lattice[x,y] = True
                    n_alive += 1
        x_min -= 1
        x_max += 1
        y_min -= 1
        y_max += 1

        print_lattice(lattice[x_min:x_max,y_min:y_max])
    return n_alive
                        
def test():
    assert (a:=countAliveCells(["oo","o."], 3)) == 36, a
    
test()

....
.oo.
.oo.
.oo.
......
..ooo.
.oooo.
.oooo.
.oooo.
.oooo.
........
...oooo.
..ooooo.
.oooooo.
.oooooo.
.oooooo.
.oooooo.
.oooooo.


AssertionError: 39

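I am not updating all the cells at the same time: a cell born during the current sweep is already treated as alive when its neighbors are checked later in the same sweep.

I will copy the lattice so that every cell is updated from the previous state, i.e. all cells are updated at once.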
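The difference is easy to see on a small grid. The sketch below runs one second of the rule starting from a single alive cell, once reading neighbors from a snapshot and once in place. The names step and fresh are hypothetical, and cells outside the 5 × 5 grid are treated as dead (unlike on the infinite plane).

```python
def step(grid, in_place=False):
    """Apply one second of the rule to a small grid (a list of lists of bool).

    Cells outside the grid are treated as dead. With in_place=False the
    neighbors are read from a snapshot of the previous state; with
    in_place=True they are read from the grid being updated.
    Returns the number of alive cells after the step.
    """
    rows, cols = len(grid), len(grid[0])
    before = grid if in_place else [row[:] for row in grid]
    for x in range(rows):
        for y in range(cols):
            neighbors = ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1))
            if any(0 <= i < rows and 0 <= j < cols and before[i][j]
                   for i, j in neighbors):
                grid[x][y] = True
    return sum(sum(row) for row in grid)

def fresh():
    """A 5 x 5 grid with only the center cell alive."""
    grid = [[False] * 5 for _ in range(5)]
    grid[2][2] = True
    return grid

print(step(fresh()))                 # 5: the center and its four neighbors
print(step(fresh(), in_place=True))  # 15: births cascade within a single sweep
```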

In [5]:
from numpy import ones, bool_
from typing import List
import copy

def get_lattice(f):
    size = [4000, 4000]
    lattice = ones(size, dtype=bool_)
    lattice = ~lattice
    
    off = 2000
    for i, row in enumerate(f):
        for j,char in enumerate(row):
            lattice[off+i,off+j] = True if char == 'o' else False
            
    # x indexes the rows of f and y indexes its columns
    x_min = off
    x_max = off+len(f)
    y_min = off
    y_max = off+len(f[0])
            
    return lattice, x_min, x_max, y_min, y_max

def print_lattice(lattice):
    
    for row in lattice:
        s = ''
        for i in range(row.size):
            s += 'o' if row[i] else '.'
        print(s)

def countAliveCells(f: List[str], K: int) -> int:
    lattice, x_min, x_max, y_min, y_max = get_lattice(f)
    before = copy.deepcopy(lattice)
    
    n_alive = sum([row.count('o') for row in f])
    
    for k in range(K):
        
        before = copy.deepcopy(lattice)
        
        for x in range(x_min, x_max+1):
            for y in range(y_min, y_max):
                if not before[x,y] and \
                  (before[x+1,y] or before[x,y+1] or before[x-1,y] or before[x,y-1]):
                    lattice[x,y] = True
                    n_alive += 1
        x_min -= 1
        x_max += 1
        y_min -= 1
        y_max += 1

        print_lattice(lattice[x_min:x_max,y_min:y_max])
    return n_alive
                        
def test():
    assert (a:=countAliveCells(["oo","o."], 3)) == 36, a
    
test()

....
.oo.
.oo.
.o..
......
..oo..
.oooo.
.oooo.
.ooo..
..o...
........
...oo...
..oooo..
.oooooo.
.oooooo.
.ooooo..
..ooo...
...o....


AssertionError: 27

The initial estimate of x_min, x_max, y_min and y_max is wrong: it must also include the neighbors of the 'o' cells, since those are the cells that can become alive in the first second.
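NumPy can also compute such bounds directly from the lattice with np.nonzero. A sketch: active_window is a hypothetical helper, and since it scans the whole array it illustrates the one-cell padding rather than replacing the incremental tracking used in the notebook.

```python
import numpy as np

def active_window(lattice: np.ndarray, pad: int = 1):
    """Return inclusive (x_min, x_max, y_min, y_max) of the alive cells, padded by `pad`.

    With pad=1 the window also contains every dead cell that shares a side
    with an alive cell, i.e. every cell that can be born in the next second.
    The caller must make sure the padded window stays inside the array.
    Returns None when no cell is alive.
    """
    xs, ys = np.nonzero(lattice)  # row and column indices of alive cells
    if xs.size == 0:
        return None
    return (int(xs.min()) - pad, int(xs.max()) + pad,
            int(ys.min()) - pad, int(ys.max()) + pad)

lattice = np.zeros((10, 10), dtype=bool)
lattice[4, 5] = True
lattice[6, 3] = True
print(active_window(lattice))  # (3, 7, 2, 6)
```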

In [6]:
from numpy import ones, bool_
from typing import List
import copy

def get_lattice(f):
    size = [4000, 4000]
    lattice = ones(size, dtype=bool_)
    lattice = ~lattice
    
    off = 2000
    for i, row in enumerate(f):
        for j,char in enumerate(row):
            lattice[off+i,off+j] = True if char == 'o' else False
            
    # x indexes the rows of f and y indexes its columns
    x_min = off-1
    x_max = off+len(f)+1
    y_min = off-1
    y_max = off+len(f[0])+1
            
    return lattice, x_min, x_max, y_min, y_max

def print_lattice(lattice):
    
    for row in lattice:
        s = ''
        for i in range(row.size):
            s += 'o' if row[i] else '.'
        print(s)

def countAliveCells(f: List[str], K: int) -> int:
    lattice, x_min, x_max, y_min, y_max = get_lattice(f)
    before = copy.deepcopy(lattice)
    
    n_alive = sum([row.count('o') for row in f])
    
    for k in range(K):
        
        before = copy.deepcopy(lattice)
        
        for x in range(x_min, x_max+1):
            for y in range(y_min, y_max):
                if not before[x,y] and \
                  (before[x+1,y] or before[x,y+1] or before[x-1,y] or before[x,y-1]):
                    lattice[x,y] = True
                    n_alive += 1
        x_min -= 1
        x_max += 1
        y_min -= 1
        y_max += 1

        print_lattice(lattice[x_min:x_max,y_min:y_max])
    return n_alive
                        
def test():
    assert (a:=countAliveCells(["oo","o."], 3)) == 36, a
    print('done')
test()

......
..oo..
.oooo.
.ooo..
..o...
......
........
...oo...
..oooo..
.oooooo.
.ooooo..
..ooo...
...o....
........
..........
....oo....
...oooo...
..oooooo..
.oooooooo.
.ooooooo..
..ooooo...
...ooo....
....o.....
..........
done


I will remove the print and keep testing.

In [7]:
from datetime import datetime
from numpy import ones, bool_
from typing import List
import copy

def get_lattice(f):
    size = [4000, 4000]
    lattice = ones(size, dtype=bool_)
    lattice = ~lattice
    
    off = 2000
    for i, row in enumerate(f):
        for j,char in enumerate(row):
            lattice[off+i,off+j] = True if char == 'o' else False
            
    # x indexes the rows of f and y indexes its columns
    x_min = off-1
    x_max = off+len(f)+1
    y_min = off-1
    y_max = off+len(f[0])+1
            
    return lattice, x_min, x_max, y_min, y_max

def print_lattice(lattice):
    
    for row in lattice:
        s = ''
        for i in range(row.size):
            s += 'o' if row[i] else '.'
        print(s)

def countAliveCells(f: List[str], K: int) -> int:
    lattice, x_min, x_max, y_min, y_max = get_lattice(f)
    before = copy.deepcopy(lattice)
    
    n_alive = sum([row.count('o') for row in f])
    
    for k in range(K):
        
        before = copy.deepcopy(lattice)
        
        for x in range(x_min, x_max+1):
            for y in range(y_min, y_max):
                if not before[x,y] and \
                  (before[x+1,y] or before[x,y+1] or before[x-1,y] or before[x,y-1]):
                    lattice[x,y] = True
                    n_alive += 1
        x_min -= 1
        x_max += 1
        y_min -= 1
        y_max += 1
        
    return n_alive
                        
def test():
    t1 = datetime.now()
    
    assert (a:=countAliveCells(["oo","o."], 3)) == 36, a
    assert (a:=countAliveCells(["..",".."], 23)) == 0, a
    assert (a:=countAliveCells(["..",".."], 1500)) == 0, a
    
    t2 = datetime.now()
    
    print(f'elapsed time {t2-t1}')
    
test()

KeyboardInterrupt: 

I stopped because it is too slow.

I will optimize it by keeping better track of x_min, x_max, y_min and y_max.

In [8]:
from datetime import datetime
from numpy import ones, bool_
from typing import List
import copy

def get_lattice(f):
    size = [4000, 4000]
    lattice = ones(size, dtype=bool_)
    lattice = ~lattice
    
    off = 2000
    for i, row in enumerate(f):
        for j,char in enumerate(row):
            lattice[off+i,off+j] = True if char == 'o' else False
            
    # x indexes the rows of f and y indexes its columns
    x_min = off-1
    x_max = off+len(f)+1
    y_min = off-1
    y_max = off+len(f[0])+1
            
    return lattice, x_min, x_max, y_min, y_max

def print_lattice(lattice):
    
    for row in lattice:
        s = ''
        for i in range(row.size):
            s += 'o' if row[i] else '.'
        print(s)

def countAliveCells(f: List[str], K: int) -> int:
    lattice, x_min, x_max, y_min, y_max = get_lattice(f)
    
    n_alive = sum([row.count('o') for row in f])
    print(n_alive)
    
    for k in range(K):
        
        before = copy.deepcopy(lattice)
        _x_min, _x_max, _y_min, _y_max = x_min, x_max, y_min, y_max
        
        for x in range(x_min, x_max+1):
            for y in range(y_min, y_max+1):
                if not before[x,y] and \
                  (before[x+1,y] or before[x,y+1] or before[x-1,y] or before[x,y-1]):
                    lattice[x,y] = True
                    n_alive += 1
                    
                    if x < _x_min:
                        _x_min = x
                    if x > _x_max:
                        _x_max = x
                    if y < _y_min:
                        _y_min = y
                    if y > _y_max:
                        _y_max = y
                        
        x_min, x_max, y_min, y_max = _x_min-1, _x_max+1, _y_min-1, _y_max+1
        
    return n_alive
                        
def test():
    t1 = datetime.now()
    
    assert (a:=countAliveCells(["oo","o."], 3)) == 36, a
    assert (a:=countAliveCells(["..",".."], 23)) == 0, a
    assert (a:=countAliveCells(["..",".."], 1500)) == 0, a
    
    t2 = datetime.now()
    
    print(f'elapsed time {t2-t1}')
    
test()

3
0
0


KeyboardInterrupt: 

It is still too slow.

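Actually, I am still growing x_min, x_max, y_min and y_max every second, even if the cluster of 'o's does not grow. I will fix this by extending the window only around cells that become alive.

Also, I do not need a complete copy of the lattice: before will now be a copy of only the window being updated, plus a one-cell border.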
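Copying only a window works because NumPy basic slicing returns a view that shares memory with the full array, while .copy() duplicates just the sliced region. A small sketch of the difference; it uses .copy(), which for a plain boolean array should have the same effect as the copy.deepcopy call in the notebook.

```python
import numpy as np

lattice = np.zeros((6, 6), dtype=bool)

# Basic slicing returns a view: it shares memory with `lattice`.
window_view = lattice[1:4, 1:4]
# .copy() duplicates only the 3 x 3 region, not the whole lattice.
window_copy = lattice[1:4, 1:4].copy()

lattice[2, 2] = True            # a cell is born in the full lattice
print(window_view[1, 1])        # True: the view sees the change
print(window_copy[1, 1])        # False: the copy keeps the old state
print(np.shares_memory(window_view, lattice))  # True
print(window_copy.size, lattice.size)          # 9 36
```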

In [10]:
from datetime import datetime
from numpy import ones, bool_
from typing import List
import copy

def get_lattice(f):
    size = [4000, 4000]
    lattice = ones(size, dtype=bool_)
    lattice = ~lattice
    
    off = 2000
    for i, row in enumerate(f):
        for j,char in enumerate(row):
            lattice[off+i,off+j] = True if char == 'o' else False
            
    # x indexes the rows of f and y indexes its columns
    x_min = off-1
    x_max = off+len(f)+1
    y_min = off-1
    y_max = off+len(f[0])+1
            
    return lattice, x_min, x_max, y_min, y_max

def print_lattice(lattice):
    
    for row in lattice:
        s = ''
        for i in range(row.size):
            s += 'o' if row[i] else '.'
        print(s)

def countAliveCells(f: List[str], K: int) -> int:
    lattice, x_min, x_max, y_min, y_max = get_lattice(f)
    
    n_alive = sum([row.count('o') for row in f])
    
    for k in range(K):
#         print(k)
        before = copy.deepcopy(lattice[x_min-1:x_max+2,y_min-1:y_max+2])
        _x_min, _x_max, _y_min, _y_max = x_min, x_max, y_min, y_max
        
        for x in range(x_min, x_max+1):
            for y in range(y_min, y_max+1):
                if not before[x-x_min+1,y-y_min+1] and \
                      (before[x+1-x_min+1,y-y_min+1] or before[x-x_min+1,y+1-y_min+1] or
                       before[x-1-x_min+1,y-y_min+1] or before[x-x_min+1,y-1-y_min+1]):
                    lattice[x,y] = True
                    n_alive += 1
                    
                    if x-1 < _x_min:
                        _x_min = x-1
                    if x+1 > _x_max:
                        _x_max = x+1
                    if y-1 < _y_min:
                        _y_min = y-1
                    if y+1 > _y_max:
                        _y_max = y+1
        x_min, x_max, y_min, y_max = _x_min, _x_max, _y_min, _y_max
        
    return n_alive
                        
def test():
    t1 = datetime.now()
    
    assert (a:=countAliveCells(["oo","o."], 3)) == 36, a
    assert (a:=countAliveCells(["..",".."], 23)) == 0, a
    assert (a:=countAliveCells(["..",".."], 1500)) == 0, a
    
    t2 = datetime.now()
    
    print(f'elapsed time {t2-t1}')
    
test()

elapsed time 0:00:00.066606


Nice, it is very fast.

Let me keep testing.

In [11]:
def test():
    t1 = datetime.now()
    
    assert (a:=countAliveCells(["oo","o."], 3)) == 36, a
    assert (a:=countAliveCells(["..",".."], 23)) == 0, a
    assert (a:=countAliveCells(["..",".."], 1500)) == 0, a
    assert (a:=countAliveCells(["o"], 1000)) == 2002001, a
    t2 = datetime.now()
    
    print(f'elapsed time {t2-t1}')
    
test()

KeyboardInterrupt: 

The last case seems to be very slow.

Let me try to optimize it by not copying the lattice: instead, I will collect the cells to update in a list and apply them all at once at the end of each second.

In [12]:
from datetime import datetime
from numpy import ones, bool_
from typing import List
import copy

def get_lattice(f):
    size = [4000, 4000]
    lattice = ones(size, dtype=bool_)
    lattice = ~lattice
    
    off = 2000
    for i, row in enumerate(f):
        for j,char in enumerate(row):
            lattice[off+i,off+j] = True if char == 'o' else False
            
    # x indexes the rows of f and y indexes its columns
    x_min = off-1
    x_max = off+len(f)+1
    y_min = off-1
    y_max = off+len(f[0])+1
            
    return lattice, x_min, x_max, y_min, y_max

def print_lattice(lattice):
    
    for row in lattice:
        s = ''
        for i in range(row.size):
            s += 'o' if row[i] else '.'
        print(s)

def countAliveCells(f: List[str], K: int) -> int:
    lattice, x_min, x_max, y_min, y_max = get_lattice(f)
    
    n_alive = sum([row.count('o') for row in f])
    
    for k in range(K):
#         print(k)
        to_update = []
        _x_min, _x_max, _y_min, _y_max = x_min, x_max, y_min, y_max
        
        for x in range(x_min, x_max+1):
            for y in range(y_min, y_max+1):
                if not lattice[x,y] and \
                      (lattice[x+1,y] or lattice[x,y+1] or
                       lattice[x-1,y] or lattice[x,y-1]):
                    to_update.append([x, y])
                    
                    if x-1 < _x_min:
                        _x_min = x-1
                    if x+1 > _x_max:
                        _x_max = x+1
                    if y-1 < _y_min:
                        _y_min = y-1
                    if y+1 > _y_max:
                        _y_max = y+1
        n_alive += len(to_update)
        for [x,y] in to_update:
            lattice[x,y] = True
        x_min, x_max, y_min, y_max = _x_min, _x_max, _y_min, _y_max
        
    return n_alive
                        
def test():
    t1 = datetime.now()
    
    assert (a:=countAliveCells(["oo","o."], 3)) == 36, a
    assert (a:=countAliveCells(["..",".."], 23)) == 0, a
    assert (a:=countAliveCells(["..",".."], 1500)) == 0, a
    assert (a:=countAliveCells(["o"], 1000)) == 2002001, a
    
    t2 = datetime.now()
    
    print(f'elapsed time {t2-t1}')
    
test()

KeyboardInterrupt: 

I stopped it early, as it did not improve the performance significantly.
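One likely reason is that this version still checks every cell of the window each second. The list of updates can be taken further: a dead cell can only be born next to a cell that was born in the previous second (if its alive neighbor were older, the cell would already have been born), so only the neighbors of the newest cells need to be checked. Below is a hedged sketch of this frontier-based variant, which is a different technique from the window scan used in the notebook; count_alive_frontier is a hypothetical name, and cells are stored in a Python set, so memory grows with the number of alive cells.

```python
from typing import List

def count_alive_frontier(f: List[str], K: int) -> int:
    """Count alive cells after K seconds by growing only from the newest cells."""
    alive = {(i, j) for i, row in enumerate(f) for j, c in enumerate(row) if c == 'o'}
    frontier = set(alive)  # cells that became alive in the previous second
    for _ in range(K):
        if not frontier:   # nothing can change any more
            break
        born = set()
        for x, y in frontier:
            for nx, ny in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
                if (nx, ny) not in alive:
                    born.add((nx, ny))
        alive |= born
        frontier = born
    return len(alive)

print(count_alive_frontier(["oo", "o."], 3))   # 36
print(count_alive_frontier(["..", ".."], 23))  # 0
print(count_alive_frontier(["o"], 100))        # 20201
```

Each cell is added once and looks at its four neighbors once, so the work is roughly proportional to the number of cells that end up alive.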

Matrix operations are usually much faster.
I can implement this with matrix operations:
if I shift the matrix by one unit to the right, left, up and down and add the shifted copies to it (for boolean arrays, addition acts as a logical OR), this is equivalent to the process described.

Let's try it and see if the performance improves.
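On a small array, the shift-and-OR step looks like this. A sketch: step_shift is a hypothetical name; it uses |= (logical OR), which for boolean arrays gives the same result as +=, and it returns a new array instead of modifying its input.

```python
import numpy as np

def step_shift(lattice: np.ndarray) -> np.ndarray:
    """One second of the rule via four shifted copies combined with OR.

    Cells outside the array are treated as dead, so the array must be
    large enough that the cluster never reaches its border.
    """
    out = lattice.copy()             # `lattice` itself stays the previous state
    out[1:, :] |= lattice[:-1, :]    # born below an alive cell
    out[:-1, :] |= lattice[1:, :]    # born above an alive cell
    out[:, 1:] |= lattice[:, :-1]    # born to the right of an alive cell
    out[:, :-1] |= lattice[:, 1:]    # born to the left of an alive cell
    return out

lattice = np.zeros((5, 5), dtype=bool)
lattice[2, 2] = True
after = step_shift(lattice)
print(after.astype(int))
# [[0 0 0 0 0]
#  [0 0 1 0 0]
#  [0 1 1 1 0]
#  [0 0 1 0 0]
#  [0 0 0 0 0]]
print(int(step_shift(after).sum()))  # 13
```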

In [13]:
from datetime import datetime
from numpy import ones, bool_
from typing import List
import copy

def get_lattice(f):
    size = [4000, 4000]
    lattice = ones(size, dtype=bool_)
    lattice = ~lattice
    
    off = 2000
    for i, row in enumerate(f):
        for j,char in enumerate(row):
            lattice[off+i,off+j] = True if char == 'o' else False
            
    x_min = off-1
    x_max = off+len(f[0])+1
    y_min = off-1
    y_max = off+len(f)+1
            
    return lattice, x_min, x_max, y_min, y_max

def print_lattice(lattice):
    
    for row in lattice:
        s = ''
        for i in range(row.size):
            s += 'o' if row[i] else '.'
        print(s)

def countAliveCells(f: List[str], K: int) -> int:
    lattice, x_min, x_max, y_min, y_max = get_lattice(f)
    
    n_alive = sum([row.count('o') for row in f])
    
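    for k in range(K):
        
        # Snapshot of the previous state, so that all cells are updated at once.
        before = copy.deepcopy(lattice)
        
        # OR in the lattice shifted by one cell down, up, right and left
        # (+= acts as a logical OR on boolean arrays).
        lattice[1:,:] += before[:-1,:]
        lattice[:-1,:] += before[1:,:]
        lattice[:,1:] += before[:,:-1]
        lattice[:,:-1] += before[:,1:]
        
    return int(lattice.sum())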
                        
def test():
    t1 = datetime.now()
    
    assert (a:=countAliveCells(["oo","o."], 3)) == 36, a
    assert (a:=countAliveCells(["..",".."], 23)) == 0, a
    assert (a:=countAliveCells(["..",".."], 1500)) == 0, a
    assert (a:=countAliveCells(["o"], 1000)) == 2002001, a
    
    t2 = datetime.now()
    
    print(f'elapsed time {t2-t1}')
    
test()


elapsed time 0:00:41.039318


It is slow, but it does the job.

Let's do the final test.

In [14]:
def test():
    t1 = datetime.now()
    
    assert (a:=countAliveCells(["o.oo.ooo","o.o.o.oo","ooo.oooo","o.o..o.o","o.o..o.o","o..oooo.","..o.o.oo","oo.ooo.o"], 1234)) == 3082590, a
    assert (a:=countAliveCells(["oo","o."], 3)) == 36, a
    assert (a:=countAliveCells(["..",".."], 23)) == 0, a
    assert (a:=countAliveCells(["..",".."], 1500)) == 0, a
    assert (a:=countAliveCells(["o"], 1000)) == 2002001, a
    
    t2 = datetime.now()
    
    print(f'elapsed time {t2-t1}')
    
test()

elapsed time 0:01:00.651124


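The solution is slow and could be improved, for example, by reducing the size of the lattice. However, it should cover all constraints: the cluster spans at most $50 + 2\times1500 = 3050$ cells along each axis, which fits in the 4000 × 4000 lattice around the offset of 2000.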
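One way to reduce the work is sketched below (count_alive_cropped is a hypothetical name, and it has not been benchmarked against the notebook). It allocates only (rows + 2K) × (cols + 2K) cells. At second t it applies the shift-and-OR update only to the window within distance t of the original field, since no cell outside that window can be alive yet.

```python
from typing import List
import numpy as np

def count_alive_cropped(f: List[str], K: int) -> int:
    """Shift-and-OR simulation on a lattice just large enough for K seconds."""
    rows, cols = len(f), len(f[0])
    # The cluster grows by at most one cell per second on each side.
    lattice = np.zeros((rows + 2 * K, cols + 2 * K), dtype=bool)
    lattice[K:K + rows, K:K + cols] = [[c == 'o' for c in row] for row in f]
    if not lattice.any():
        return 0
    for t in range(1, K + 1):
        # Alive cells at second t-1 lie strictly inside this window,
        # and every cell born at second t lies within it.
        w = lattice[K - t:K + rows + t, K - t:K + cols + t]  # a view into lattice
        before = w.copy()
        w[1:, :] |= before[:-1, :]
        w[:-1, :] |= before[1:, :]
        w[:, 1:] |= before[:, :-1]
        w[:, :-1] |= before[:, 1:]
    return int(lattice.sum())

print(count_alive_cropped(["oo", "o."], 3))  # 36
print(count_alive_cropped(["ooo"], 1))       # 11
```

Compared with a fixed 4000 × 4000 lattice, the early seconds touch only small windows, which is where most of the saving comes from.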