# Merge and split

## concat

Example:
- [class1-4, students, scores]
- [class5-6, students, scores]

merge class1-4 and class5-6

In [2]:
import tensorflow as tf

In [12]:
# Real-world reading: tensor a holds 4 classes x 35 students x 8 course scores,
# tensor b holds 2 more classes with the same 35 students x 8 scores.
# Merging them gives 6 classes x 35 students x 8 scores.
# We concatenate along the class dimension, i.e. axis=0; the other dims (35, 8) must agree.
a = tf.ones(shape=[4, 35, 8])
b = tf.ones(shape=[2, 35, 8])
c = tf.concat(values=[a, b], axis=0)
c.shape

TensorShape([6, 35, 8])

In [6]:
# Same 4 classes, split into two groups of students: 32 students and 3 students,
# each with 8 course scores. Merging the students gives 4 classes x 35 students x 8 scores.
# Concatenate along the students dimension (axis=1); dims 0 (4) and 2 (8) must agree.
a = tf.ones(shape=[4, 32, 8])
b = tf.ones(shape=[4, 3, 8])
c = tf.concat(values=[a, b], axis=1)
c.shape

TensorShape([4, 35, 8])

## stack: create new dim

Example:
- school1: [classes, students, scores]
- school2: [classes, students, scores]

merge school1 and school2: [schools, classes, students, scores] (a new leading dimension of length 2)

In [11]:
# Two "schools" with identical [classes, students, scores] shapes.
a = tf.ones(shape=[4, 35, 8])
b = tf.ones(shape=[4, 35, 8])
# stack inserts a brand-new axis of length 2 (one entry per input tensor).
print(tf.stack(values=[a, b], axis=0).shape)   # new leading axis  -> (2, 4, 35, 8)
print(tf.stack(values=[a, b], axis=-1).shape)  # new trailing axis -> (4, 35, 8, 2)

(2, 4, 35, 8)
(4, 35, 8, 2)


## Dim mismatch

tf.concat() 和 tf.stack() 都在维度上有要求。

tf.concat() 要求除了 axis 之外的维度都相同，tf.stack() 要求所有维度都相同。

In [14]:
# tf.concat requires every dimension except `axis` to match.
# Here dim 0 is the concat axis (4 vs 3 is fine), but dim 1 differs (35 vs 33),
# so the op fails. The error is caught and printed so that
# "Restart Kernel -> Run All" does not stop at this demonstration cell.
a = tf.ones([4, 35, 8])
b = tf.ones([3, 33, 8])
try:
    tf.concat([a, b], axis=0)
except tf.errors.InvalidArgumentError as err:
    # Observed message:
    # ConcatOp : Dimensions of inputs should match:
    # shape[0] = [4,35,8] vs. shape[1] = [3,33,8] [Op:ConcatV2] name: concat
    print(f"{type(err).__name__}: {err.message}")

In [17]:
# tf.stack requires ALL input shapes to be identical (it adds a new axis
# rather than growing an existing one). Dim 0 differs here (4 vs 3), so the op fails.
# The error is caught and printed so "Restart Kernel -> Run All" keeps going.
a = tf.ones([4, 35, 8])
b = tf.ones([3, 35, 8])
try:
    tf.stack([a, b], axis=0)
except tf.errors.InvalidArgumentError as err:
    # Observed message:
    # Shapes of all inputs must match: values[0].shape = [4,35,8] != values[1].shape = [3,35,8] [Op:Pack] name: stack
    print(f"{type(err).__name__}: {err.message}")

## Unstack

In [27]:
# Build a [2, 4, 35, 8] tensor to unstack: two [4, 35, 8] tensors stacked on a new axis 0.
a = tf.ones(shape=[4, 35, 8])
b = tf.ones(shape=[4, 35, 8])
c = tf.stack(values=[a, b], axis=0)
c.shape

TensorShape([2, 4, 35, 8])

In [28]:
# Unstack along axis=0, whose length is 2 -> exactly 2 tensors, each [4, 35, 8].
aa, bb = tf.unstack(value=c, axis=0)
(aa.shape, bb.shape)

(TensorShape([4, 35, 8]), TensorShape([4, 35, 8]))

In [25]:
# Unstack along axis=3, whose length is 8 -> a list of 8 tensors.
# The unstacked axis is removed, so each piece of the [2, 4, 35, 8] tensor c is [2, 4, 35].
res = tf.unstack(c, axis=3)
assert len(res) == c.shape[3], "unstack yields one piece per entry along the axis"
res[0].shape, res[7].shape

(TensorShape([2, 4, 35]), TensorShape([2, 4, 35]))

## Split
Unstack 只能沿指定维度拆成"该维度长度"那么多份，每份大小为 1，且该维度会被去掉；如果想自定义份数或每份的大小（例如 8 拆成 2+2+4），可以使用 split，split 会保留该维度。

In [31]:
# Rebuild c as a fresh [2, 4, 35, 8] tensor; unstacking axis 3 always yields 8 pieces.
c = tf.ones(shape=[2, 4, 35, 8])
res = tf.unstack(value=c, axis=3)
len(res)

8

In [34]:
# An integer num_or_size_splits cuts axis 3 (length 8) into 2 equal parts of 4.
# Unlike unstack, split keeps the axis in every piece.
res = tf.split(value=c, num_or_size_splits=2, axis=3)
print(len(res))
res[0].shape, res[1].shape

2


(TensorShape([2, 4, 35, 4]), TensorShape([2, 4, 35, 4]))

In [35]:
# 将8，分成2+2+4
res = tf.split(c, axis=3, num_or_size_splits=[2, 2, 4])
print(len(res))
res[0].shape, res[1].shape, res[2].shape

3


(TensorShape([2, 4, 35, 2]),
 TensorShape([2, 4, 35, 2]),
 TensorShape([2, 4, 35, 4]))